From 32eea2f1a50430786f6de85a639a63d75ce30511 Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Sat, 5 Apr 2025 09:27:04 +0200 Subject: [PATCH 001/157] added timers for comm --- src/serinv/wrappers/pddbtasc.py | 9 +++++++++ src/serinv/wrappers/pddbtsc.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index da2fcff4..7cb0374c 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -202,12 +204,17 @@ def pddbtasc( _rhs=ddbtars.get("_rhs", None), ) + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_ddbtars( ddbtars=ddbtars, quadratic=quadratic, comm=comm, strategy=strategy, ) + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic ddbtars["A_arrow_tip_block"][:] += A_arrow_tip_initial if quadratic: @@ -226,3 +233,5 @@ def pddbtasc( ) comm.Barrier() + + return elapsed diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index fc0a3765..b846bb07 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -163,12 +165,17 @@ def pddbtsc( _rhs=ddbtrs.get("_rhs", None), ) + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_ddbtrs( ddbtrs=ddbtrs, quadratic=quadratic, comm=comm, strategy=strategy, ) + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Perform Schur complement on the reduced system ddbtsc( @@ -180,3 +187,5 @@ def pddbtsc( ) comm.Barrier() + + return elapsed \ No newline at end of file From 2dc799bc8893a28eb58e5f22020ceecfc114fdaa Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Sat, 5 Apr 2025 18:13:10 +0200 Subject: [PATCH 002/157] added nccl --- src/serinv/__init__.py | 6 +- src/serinv/wrappers/ddbtars.py | 109 +++++++++++++++++-------- src/serinv/wrappers/ddbtrs.py | 75 ++++++++++++----- src/serinv/wrappers/pddbtasc.py | 8 +- src/serinv/wrappers/pddbtasci.py | 3 + src/serinv/wrappers/pddbtsc.py | 6 ++ src/serinv/wrappers/pddbtsci.py | 6 +- src/serinv/wrappers/pobtars.py | 135 ++++++++++++++++++++++++------- src/serinv/wrappers/pobtrs.py | 115 +++++++++++++++++++++----- src/serinv/wrappers/ppobtaf.py | 16 ++++ src/serinv/wrappers/ppobtas.py | 16 ++++ src/serinv/wrappers/ppobtasi.py | 4 +- src/serinv/wrappers/ppobtf.py | 16 ++++ src/serinv/wrappers/ppobts.py | 17 ++++ src/serinv/wrappers/ppobtsi.py | 4 +- 15 files changed, 426 insertions(+), 110 deletions(-) diff --git a/src/serinv/__init__.py b/src/serinv/__init__.py index ca7f21dd..bad9c860 100644 --- a/src/serinv/__init__.py +++ b/src/serinv/__init__.py @@ -163,7 +163,7 @@ def _use_nccl(comm): return False -def _get_nccl_parameters(arr, comm, op: str): +def _get_nccl_parameters(arr, comm, rank, op: str): """Get the NCCL parameters for the given operation.""" if np.iscomplexobj(arr): factor = 2 @@ -172,8 +172,8 @@ def _get_nccl_parameters(arr, comm, op: str): if backend_flags["nccl_avail"]: if op == "allgather": - count = (arr.size // comm.size) * factor - displacement = count * comm.rank * arr.dtype.itemsize + count = (arr.size // comm.size()) * factor + displacement = count * rank * (arr.dtype.itemsize // factor) elif op == "allreduce": count = arr.size * factor displacement = 0 diff --git a/src/serinv/wrappers/ddbtars.py b/src/serinv/wrappers/ddbtars.py index 46be5972..722582a6 100644 --- a/src/serinv/wrappers/ddbtars.py +++ b/src/serinv/wrappers/ddbtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_ddbtars( A_diagonal_blocks: ArrayLike, @@ -30,7 +27,14 @@ def allocate_ddbtars( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -80,7 +84,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -144,7 +148,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -224,8 +228,15 @@ def map_ddbtasc_to_ddbtars( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str, + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +371,14 @@ def aggregate_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -455,7 +473,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. if comm_rank == 0: @@ -564,11 +582,12 @@ def aggregate_ddbtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -576,9 +595,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -586,9 +605,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -596,9 +615,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -606,9 +625,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_arrow_blocks_comm.data.ptr, count=count, @@ -616,17 +635,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -662,11 +682,12 @@ def aggregate_ddbtars( ddbtars["A_upper_arrow_blocks"] = _A_upper_arrow_blocks[1:] if quadratic: - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -674,9 +695,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -684,9 +705,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -694,9 +715,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_arrow_blocks_comm.data.ptr, count=count, @@ -704,9 +725,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_arrow_blocks_comm.data.ptr, count=count, @@ -714,17 +735,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_B_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_arrow_tip_block_comm.data.ptr, recvbuf=_B_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -763,7 +785,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -787,8 +809,18 @@ def scatter_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() + _A_diagonal_blocks: ArrayLike = ddbtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = ddbtars.get("A_lower_diagonal_blocks", None) _A_upper_diagonal_blocks: ArrayLike = ddbtars.get("A_upper_diagonal_blocks", None) @@ -857,8 +889,15 @@ def map_ddbtars_to_ddbtasci( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ddbtrs.py b/src/serinv/wrappers/ddbtrs.py index 15393ae9..937d8def 100644 --- a/src/serinv/wrappers/ddbtrs.py +++ b/src/serinv/wrappers/ddbtrs.py @@ -24,7 +24,14 @@ def allocate_ddbtrs( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -57,7 +64,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -98,7 +105,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -150,8 +157,15 @@ def map_ddbtsc_to_ddbtrs( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -241,7 +255,14 @@ def aggregate_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -304,7 +325,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. if comm_rank == 0: @@ -379,11 +400,12 @@ def aggregate_ddbtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -391,9 +413,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -401,9 +423,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -411,6 +433,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -433,11 +456,12 @@ def aggregate_ddbtrs( ddbtrs["A_upper_diagonal_blocks"] = _A_upper_diagonal_blocks[1:-2] if quadratic: - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -445,9 +469,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -455,9 +479,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -465,6 +489,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -492,7 +517,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -510,8 +535,15 @@ def scatter_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -565,8 +597,15 @@ def map_ddbtrs_to_ddbtsci( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index 7cb0374c..6c756c92 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -24,6 +24,7 @@ def pddbtasc( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -117,9 +118,10 @@ def pddbtasc( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) - buffers: dict = kwargs.get("buffers", None) ddbtars: dict = kwargs.get("ddbtars", None) strategy: str = kwargs.get("strategy", "allgather") @@ -202,6 +204,7 @@ def pddbtasc( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) MPI.COMM_WORLD.Barrier() @@ -211,7 +214,10 @@ def pddbtasc( quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() MPI.COMM_WORLD.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/pddbtasci.py b/src/serinv/wrappers/pddbtasci.py index f86b6a9c..f235d76b 100644 --- a/src/serinv/wrappers/pddbtasci.py +++ b/src/serinv/wrappers/pddbtasci.py @@ -22,6 +22,7 @@ def pddbtasci( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. @@ -159,6 +160,7 @@ def pddbtasci( ddbtars=ddbtars, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtars_to_ddbtasci( @@ -179,6 +181,7 @@ def pddbtasci( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index b846bb07..c1ec3b72 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -21,6 +21,7 @@ def pddbtsc( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -84,6 +85,7 @@ def pddbtsc( - _B_upper_diagonal_blocks : ArrayLike The upper diagonal blocks of the reduced system. """ + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -163,6 +165,7 @@ def pddbtsc( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) MPI.COMM_WORLD.Barrier() @@ -172,7 +175,10 @@ def pddbtsc( quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() MPI.COMM_WORLD.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/pddbtsci.py b/src/serinv/wrappers/pddbtsci.py index 144054f8..63d94c84 100644 --- a/src/serinv/wrappers/pddbtsci.py +++ b/src/serinv/wrappers/pddbtsci.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import ddbtsci @@ -19,6 +18,7 @@ def pddbtsci( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. @@ -89,8 +89,6 @@ def pddbtsci( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") - xp, _ = _get_module_from_array(arr=A_diagonal_blocks) - rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) buffers: dict = kwargs.get("buffers", None) @@ -125,6 +123,7 @@ def pddbtsci( ddbtrs=ddbtrs, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtrs_to_ddbtsci( @@ -139,6 +138,7 @@ def pddbtsci( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pobtars.py b/src/serinv/wrappers/pobtars.py index 98cfdf46..dffa86e8 100644 --- a/src/serinv/wrappers/pobtars.py +++ b/src/serinv/wrappers/pobtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_pobtars( A_diagonal_blocks: ArrayLike, @@ -29,6 +26,7 @@ def allocate_pobtars( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PPOBTARX algorithms. @@ -56,6 +54,12 @@ def allocate_pobtars( pobtars : dict Dictionary containing the reduced system arrays. """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -106,7 +110,7 @@ def allocate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -152,6 +156,7 @@ def map_ppobtax_to_pobtars( buffer: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the the boundary blocks of the PPOBTAX algorithm to the reduced system. @@ -178,6 +183,12 @@ def map_ppobtax_to_pobtars( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -228,9 +239,16 @@ def map_ppobtas_to_pobtarss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -252,6 +270,7 @@ def aggregate_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -269,7 +288,14 @@ def aggregate_pobtars( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -306,7 +332,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. _A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -317,11 +343,12 @@ def aggregate_pobtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -329,9 +356,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -339,9 +366,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -349,17 +376,18 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -374,7 +402,7 @@ def aggregate_pobtars( ) comm.Allreduce(MPI.IN_PLACE, _A_arrow_tip_block_comm, op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -432,7 +460,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -449,8 +477,15 @@ def aggregate_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -474,7 +509,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. @@ -484,10 +519,11 @@ def aggregate_pobtarss( if strategy == "allgather": if _use_nccl(comm): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[:-a], comm=comm, op="allgather" + arr=_B_comm[:-a], comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm[:-a].data.ptr + displacement, recvbuf=_B_comm[:-a].data.ptr, count=count, @@ -495,24 +531,25 @@ def aggregate_pobtarss( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[-a:], comm=comm, op="allreduce" + arr=_B_comm[-a:], comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_comm[-a:].data.ptr, recvbuf=_B_comm[-a:].data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm[:-a], ) comm.Allreduce(MPI.IN_PLACE, _B_comm[-a:], op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -547,7 +584,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -559,10 +596,18 @@ def scatter_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -598,6 +643,11 @@ def scatter_pobtars( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -607,7 +657,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -652,7 +702,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -671,9 +721,17 @@ def scatter_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() b = A_diagonal_blocks[0].shape[0] a = A_arrow_tip_block.shape[0] @@ -695,6 +753,11 @@ def scatter_pobtarss( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -703,7 +766,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -727,7 +790,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -747,6 +810,7 @@ def map_pobtars_to_ppobtax( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -772,6 +836,12 @@ def map_pobtars_to_ppobtax( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -824,9 +894,16 @@ def map_pobtarss_to_ppobtas( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pobtrs.py b/src/serinv/wrappers/pobtrs.py index c716317a..e3f48a2e 100644 --- a/src/serinv/wrappers/pobtrs.py +++ b/src/serinv/wrappers/pobtrs.py @@ -24,6 +24,7 @@ def allocate_pobtrs( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PpobtRX algorithms. @@ -47,6 +48,12 @@ def allocate_pobtrs( pobtrs : dict Dictionary containing the reduced system arrays. """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -92,7 +99,7 @@ def allocate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -126,6 +133,7 @@ def map_ppobtx_to_pobtrs( comm: MPI.Comm, buffer: ArrayLike, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: """Map the the boundary blocks of the PpobtX algorithm to the reduced system. @@ -144,6 +152,12 @@ def map_ppobtx_to_pobtrs( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -183,9 +197,16 @@ def map_ppobts_to_pobtrss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -204,6 +225,7 @@ def aggregate_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -215,6 +237,12 @@ def aggregate_pobtrs( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -241,7 +269,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. _A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -250,11 +278,12 @@ def aggregate_pobtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -262,9 +291,9 @@ def aggregate_pobtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -272,6 +301,7 @@ def aggregate_pobtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -281,7 +311,7 @@ def aggregate_pobtrs( _A_lower_diagonal_blocks_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -322,7 +352,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -336,8 +366,15 @@ def aggregate_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +397,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. @@ -369,11 +406,12 @@ def aggregate_pobtrss( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm, comm=comm, op="allgather" + arr=_B_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm.data.ptr + displacement, recvbuf=_B_comm.data.ptr, count=count, @@ -381,12 +419,13 @@ def aggregate_pobtrss( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -415,7 +454,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -427,9 +466,16 @@ def scatter_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -456,6 +502,11 @@ def scatter_pobtrs( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -465,7 +516,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -496,7 +547,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -513,8 +564,15 @@ def scatter_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -536,6 +594,11 @@ def scatter_pobtrss( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -544,7 +607,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -564,7 +627,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -580,6 +643,7 @@ def map_pobtrs_to_ppobtx( _A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -597,6 +661,12 @@ def map_pobtrs_to_ppobtx( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -638,9 +708,16 @@ def map_pobtrss_to_ppobts( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ppobtaf.py b/src/serinv/wrappers/ppobtaf.py index 64946a3d..240395e0 100644 --- a/src/serinv/wrappers/ppobtaf.py +++ b/src/serinv/wrappers/ppobtaf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -20,6 +22,7 @@ def ppobtaf( A_lower_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -62,6 +65,8 @@ def ppobtaf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -133,14 +138,23 @@ def ppobtaf( buffer=buffer, strategy=strategy, comm=comm, + nccl_comm=nccl_comm, ) + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_pobtars( pobtars=pobtars, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- pobtars["A_arrow_tip_block"][:] += A_arrow_tip_initial @@ -165,3 +179,5 @@ def ppobtaf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtas.py b/src/serinv/wrappers/ppobtas.py index 132a9d08..2feda2a2 100644 --- a/src/serinv/wrappers/ppobtas.py +++ b/src/serinv/wrappers/ppobtas.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -23,6 +25,7 @@ def ppobtas( L_arrow_tip_block: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -138,16 +141,25 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Agregate reduced RHS + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_pobtarss( A_diagonal_blocks=L_diagonal_blocks, A_arrow_tip_block=L_arrow_tip_block, pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Add the tip block of the RHS to the aggregated update _B[-a:] += B_tip_initial @@ -180,6 +192,7 @@ def ppobtas( pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -190,6 +203,7 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -213,3 +227,5 @@ def ppobtas( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtasi.py b/src/serinv/wrappers/ppobtasi.py index 12e23e14..96b67c20 100644 --- a/src/serinv/wrappers/ppobtasi.py +++ b/src/serinv/wrappers/ppobtasi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtasi @@ -20,6 +19,7 @@ def ppobtasi( L_lower_arrow_blocks: ArrayLike, L_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -114,6 +114,7 @@ def ppobtasi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -129,6 +130,7 @@ def ppobtasi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system diff --git a/src/serinv/wrappers/ppobtf.py b/src/serinv/wrappers/ppobtf.py index 2d680c96..b8f8e839 100644 --- a/src/serinv/wrappers/ppobtf.py +++ b/src/serinv/wrappers/ppobtf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -18,6 +20,7 @@ def ppobtf( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -56,6 +59,8 @@ def ppobtf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -106,14 +111,23 @@ def ppobtf( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_pobtrs( pobtrs=pobtrs, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- if strategy == "gather-scatter": @@ -134,3 +148,5 @@ def ppobtf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobts.py b/src/serinv/wrappers/ppobts.py index 396577f1..bef5b48e 100644 --- a/src/serinv/wrappers/ppobts.py +++ b/src/serinv/wrappers/ppobts.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -21,6 +23,7 @@ def ppobts( L_lower_diagonal_blocks: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -119,14 +122,24 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + # Agregate reduced RHS + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_pobtrss( A_diagonal_blocks=L_diagonal_blocks, pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Solve RHS FWD/BWD if strategy == "allgather": @@ -151,6 +164,7 @@ def ppobts( pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -160,6 +174,7 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -179,3 +194,5 @@ def ppobts( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtsi.py b/src/serinv/wrappers/ppobtsi.py index fde320fc..72c7354d 100644 --- a/src/serinv/wrappers/ppobtsi.py +++ b/src/serinv/wrappers/ppobtsi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtsi @@ -18,6 +17,7 @@ def ppobtsi( L_diagonal_blocks: ArrayLike, L_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -100,6 +100,7 @@ def ppobtsi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -111,6 +112,7 @@ def ppobtsi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system From 2d66457349d4dafc3ec4aa54153898afec2a086d Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Sat, 5 Apr 2025 22:41:34 +0200 Subject: [PATCH 003/157] removed explicit comm world call --- src/serinv/wrappers/pddbtasc.py | 4 ++-- src/serinv/wrappers/pddbtsc.py | 4 ++-- src/serinv/wrappers/ppobtaf.py | 4 ++-- src/serinv/wrappers/ppobtas.py | 4 ++-- src/serinv/wrappers/ppobtf.py | 4 ++-- src/serinv/wrappers/ppobts.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index 6c756c92..27335195 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -207,7 +207,7 @@ def pddbtasc( nccl_comm=nccl_comm, ) - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_ddbtars( ddbtars=ddbtars, @@ -218,7 +218,7 @@ def pddbtasc( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index c1ec3b72..e9f1eb9e 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -168,7 +168,7 @@ def pddbtsc( nccl_comm=nccl_comm, ) - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_ddbtrs( ddbtrs=ddbtrs, @@ -179,7 +179,7 @@ def pddbtsc( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/ppobtaf.py b/src/serinv/wrappers/ppobtaf.py index 240395e0..2122b88a 100644 --- a/src/serinv/wrappers/ppobtaf.py +++ b/src/serinv/wrappers/ppobtaf.py @@ -141,7 +141,7 @@ def ppobtaf( nccl_comm=nccl_comm, ) - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_pobtars( pobtars=pobtars, @@ -152,7 +152,7 @@ def ppobtaf( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/ppobtas.py b/src/serinv/wrappers/ppobtas.py index 2feda2a2..52896530 100644 --- a/src/serinv/wrappers/ppobtas.py +++ b/src/serinv/wrappers/ppobtas.py @@ -145,7 +145,7 @@ def ppobtas( ) # Agregate reduced RHS - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_pobtarss( A_diagonal_blocks=L_diagonal_blocks, @@ -157,7 +157,7 @@ def ppobtas( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/ppobtf.py b/src/serinv/wrappers/ppobtf.py index b8f8e839..f1b44956 100644 --- a/src/serinv/wrappers/ppobtf.py +++ b/src/serinv/wrappers/ppobtf.py @@ -114,7 +114,7 @@ def ppobtf( nccl_comm=nccl_comm, ) - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_pobtrs( pobtrs=pobtrs, @@ -125,7 +125,7 @@ def ppobtf( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/ppobts.py b/src/serinv/wrappers/ppobts.py index bef5b48e..86906883 100644 --- a/src/serinv/wrappers/ppobts.py +++ b/src/serinv/wrappers/ppobts.py @@ -126,7 +126,7 @@ def ppobts( ) # Agregate reduced RHS - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_pobtrss( A_diagonal_blocks=L_diagonal_blocks, @@ -137,7 +137,7 @@ def ppobts( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic From 791daa7bc32eca05c368d013cfaf91f9e5051e66 Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Sat, 5 Apr 2025 23:58:04 +0200 Subject: [PATCH 004/157] added nccl synch --- src/serinv/wrappers/pobtars.py | 3 +++ src/serinv/wrappers/pobtrs.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/serinv/wrappers/pobtars.py b/src/serinv/wrappers/pobtars.py index dffa86e8..ee6f26bb 100644 --- a/src/serinv/wrappers/pobtars.py +++ b/src/serinv/wrappers/pobtars.py @@ -345,6 +345,7 @@ def aggregate_pobtars( if strategy == "allgather": if _use_nccl(communicator): # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) @@ -386,6 +387,8 @@ def aggregate_pobtars( op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: # --- Use MPI --- comm.Allgather( diff --git a/src/serinv/wrappers/pobtrs.py b/src/serinv/wrappers/pobtrs.py index e3f48a2e..391953d4 100644 --- a/src/serinv/wrappers/pobtrs.py +++ b/src/serinv/wrappers/pobtrs.py @@ -280,6 +280,7 @@ def aggregate_pobtrs( if strategy == "allgather": if _use_nccl(communicator): # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) @@ -300,6 +301,8 @@ def aggregate_pobtrs( datatype=datatype, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: # --- Use MPI --- comm.Allgather( From 9dca6a88c0a9c8ff684b0f0935654d6b8a4d6481 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 27 May 2025 14:51:49 +0000 Subject: [PATCH 005/157] first implementation of trsm from cupy and scipy for left and right hand side --- src/serinv/utils/trsm_solve_device.py | 102 +++++++++++++++++++++ src/serinv/utils/trsm_solve_host.py | 122 ++++++++++++++++++++++++++ 2 files changed, 224 insertions(+) create mode 100644 src/serinv/utils/trsm_solve_device.py create mode 100644 src/serinv/utils/trsm_solve_host.py diff --git a/src/serinv/utils/trsm_solve_device.py b/src/serinv/utils/trsm_solve_device.py new file mode 100644 index 00000000..9d6bdb43 --- /dev/null +++ b/src/serinv/utils/trsm_solve_device.py @@ -0,0 +1,102 @@ +import numpy + +from cupy.cuda import cublas +from cupy.cuda import device +from cupy.linalg import _util + + +def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=False, aplha = 1., side=0): + """Solve the equation a x = b for x, assuming a is a triangular matrix. + + Args: + a (cupy.ndarray): The matrix with dimension ``(M, M)``. + b (cupy.ndarray): The matrix with dimension ``(M,)`` or + ``(M, N)``. + lower (bool): Use only data contained in the lower triangle of ``a``. + Default is to use upper triangle. + trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: + + - *'0'* or *'N'* -- :math:`a x = b` + - *'1'* or *'T'* -- :math:`a^T x = b` + - *'2'* or *'C'* -- :math:`a^H x = b` + + unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are + assumed to be 1 and will not be referenced. + overwrite_b (bool): Allow overwriting data in b (may enhance + performance) + check_finite (bool): Whether to check that the input matrices contain + only finite numbers. Disabling may give a performance gain, but may + result in problems (crashes, non-termination) if the inputs do + contain infinities or NaNs. + + Returns: + cupy.ndarray: + The matrix with dimension ``(M,)`` or ``(M, N)``. + + .. seealso:: :func:`scipy.linalg.solve_triangular` + """ + + _util._assert_cupy_array(a, b) + + if len(a.shape) != 2 or a.shape[0] != a.shape[1]: + raise ValueError('expected square matrix') + if len(a) != len(b): + raise ValueError('incompatible dimensions') + + # Cast to float32 or float64 + if a.dtype.char in 'fd': + dtype = a.dtype + else: + dtype = numpy.promote_types(a.dtype.char, 'f') + + a = cupy.array(a, dtype=dtype, order='F', copy=False) + b = cupy.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) + + if check_finite: + if a.dtype.kind == 'f' and not cupy.isfinite(a).all(): + raise ValueError( + 'array must not contain infs or NaNs') + if b.dtype.kind == 'f' and not cupy.isfinite(b).all(): + raise ValueError( + 'array must not contain infs or NaNs') + + m, n = (b.size, 1) if b.ndim == 1 else b.shape + cublas_handle = device.get_cublas_handle() + + if dtype == 'f': + trsm = cublas.strsm + elif dtype == 'd': + trsm = cublas.dtrsm + elif dtype == 'F': + trsm = cublas.ctrsm + else: # dtype == 'D' + trsm = cublas.ztrsm + one = numpy.array(1, dtype=dtype) + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + if trans == 'N': + trans = cublas.CUBLAS_OP_N + elif trans == 'T': + trans = cublas.CUBLAS_OP_T + elif trans == 'C': + trans = cublas.CUBLAS_OP_C + + if unit_diagonal: + diag = cublas.CUBLAS_DIAG_UNIT + else: + diag = cublas.CUBLAS_DIAG_NON_UNIT + + if side: + blas_side = cublas.CUBLAS_SIDE_RIGHT + else: + blas_side = cublas.CUBLAS_SIDE_LEFT + + trsm( + cublas_handle, blas_side, uplo, + trans, diag, + m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) \ No newline at end of file diff --git a/src/serinv/utils/trsm_solve_host.py b/src/serinv/utils/trsm_solve_host.py new file mode 100644 index 00000000..820770e4 --- /dev/null +++ b/src/serinv/utils/trsm_solve_host.py @@ -0,0 +1,122 @@ +import numpy as np + + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=True, side=0): + """ + Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. + + Parameters + ---------- + a : (M, M) array_like + A triangular matrix + b : (M,) or (M, N) array_like + Right-hand side matrix in ``a x = b`` + lower : bool, optional + Use only data contained in the lower triangle of `a`. + Default is to use upper triangle. + trans : {0, 1, 2, 'N', 'T', 'C'}, optional + Type of system to solve: + + ======== ========= + trans system + ======== ========= + 0 or 'N' a x = b + 1 or 'T' a^T x = b + 2 or 'C' a^H x = b + ======== ========= + unit_diagonal : bool, optional + If True, diagonal elements of `a` are assumed to be 1 and + will not be referenced. + overwrite_b : bool, optional + Allow overwriting data in `b` (may enhance performance) + check_finite : bool, optional + Whether to check that the input matrices contain only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Returns + ------- + x : (M,) or (M, N) ndarray + Solution to the system ``a x = b``. Shape of return matches `b`. + + Raises + ------ + LinAlgError + If `a` is singular + + Notes + ----- + .. versionadded:: 0.9.0 + + Examples + -------- + Solve the lower triangular system a x = b, where:: + + [3 0 0 0] [4] + a = [2 1 0 0] b = [2] + [1 0 1 0] [4] + [1 1 1 1] [2] + + >>> import numpy as np + >>> from scipy.linalg import solve_triangular + >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) + >>> b = np.array([4, 2, 4, 2]) + >>> x = solve_triangular(a, b, lower=True) + >>> x + array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) + >>> a.dot(x) # Check the result + array([ 4., 2., 4., 2.]) + + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + b1 = _asarray_validated(b, check_finite=check_finite) + + if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: + raise ValueError('expected square matrix') + + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + # accommodate empty arrays + if b1.size == 0: + dt_nonempty = solve_triangular_host( + np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) + ).dtype + return np.empty_like(b1, dtype=dt_nonempty) + + overwrite_b = overwrite_b or _datacopied(b1, b) + + x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + return x + + +# solve_triangular without the input validation +def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + trsm, = get_blas_funcs(('trsm',), (a1, b1)) + + if a1.dtype.char in 'fd': + dtype = a1.dtype + else: + dtype = np.promote_types(a1.dtype.char, 'f') + + one = np.array(1, dtype=dtype) + alpha = one.ctypes.data + + if a1.flags.f_contiguous or trans == 2: + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, + trans_a=trans, diag=unit_diagonal, side=side) + else: + # transposed system is solved since trtrs expects Fortran ordering + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + trans_a=not trans, diag=unit_diagonal, side=side) + + return x From 7e3884e7a03cab3653f2d5f94f27cc0b79bd53bb Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:03:11 +0000 Subject: [PATCH 006/157] matmul implementation --- .../utils/{trsm_solve_host.py => matmul.py} | 51 ++-- src/serinv/utils/trsm.py | 246 ++++++++++++++++++ src/serinv/utils/trsm_solve_device.py | 102 -------- 3 files changed, 279 insertions(+), 120 deletions(-) rename src/serinv/utils/{trsm_solve_host.py => matmul.py} (73%) create mode 100644 src/serinv/utils/trsm.py delete mode 100644 src/serinv/utils/trsm_solve_device.py diff --git a/src/serinv/utils/trsm_solve_host.py b/src/serinv/utils/matmul.py similarity index 73% rename from src/serinv/utils/trsm_solve_host.py rename to src/serinv/utils/matmul.py index 820770e4..aa68ab31 100644 --- a/src/serinv/utils/trsm_solve_host.py +++ b/src/serinv/utils/matmul.py @@ -1,12 +1,31 @@ -import numpy as np +from serinv import _get_module_from_array +import numpy as np +from numpy.linalg import matmul from scipy.linalg.blas import get_blas_funcs from scipy.linalg._misc import _datacopied from scipy.linalg._decomp import _asarray_validated -def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, check_finite=True, side=0): +try: + import cupy as cp + from cupy.cublas import gemm +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def serinv_matmul (a, b): + """Wrapper to call GeMM for host or device""" + xp, la = _get_module_from_array(a) + + if xp == np: + return matmul(a, b) + elif xp == cp: + return gemm('N', 'N', a, b) + else: + ModuleNotFoundError("Unknown Module") + + +def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): """ Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. @@ -85,23 +104,21 @@ def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, # accommodate empty arrays if b1.size == 0: - dt_nonempty = solve_triangular_host( + dt_nonempty = matmul_gemm_host( np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - overwrite_b = overwrite_b or _datacopied(b1, b) - - x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + x = _solve_triangular(a1, b1, trans_a, trans_b, overwrite_c) return x # solve_triangular without the input validation -def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, side=0): +def _solve_triangular(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): - trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) - trsm, = get_blas_funcs(('trsm',), (a1, b1)) + trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) + trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) + gemm, = get_blas_funcs(('gemm',), (a1, b1)) if a1.dtype.char in 'fd': dtype = a1.dtype @@ -109,14 +126,12 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, dtype = np.promote_types(a1.dtype.char, 'f') one = np.array(1, dtype=dtype) + zero =np.array(0, dtype=dtype) alpha = one.ctypes.data + beta = zero.ctypes.data - if a1.flags.f_contiguous or trans == 2: - x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, - trans_a=trans, diag=unit_diagonal, side=side) - else: - # transposed system is solved since trtrs expects Fortran ordering - x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, - trans_a=not trans, diag=unit_diagonal, side=side) + + x = gemm(alpha, a1.T, b1.T, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + return x diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py new file mode 100644 index 00000000..159f32a4 --- /dev/null +++ b/src/serinv/utils/trsm.py @@ -0,0 +1,246 @@ +import numpy as np + +from serinv import _get_module_from_array + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +try: + import cupy as cp + from cupy.cuda import cublas + from cupy.cuda import device + from cupy.linalg import _util +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, + overwrite_b=False, check_finite=False, side=0): + """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device + + For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept + plus the side parameter which can either be 0 or 1 for left or right hand side + """ + xp = _get_module_from_array(a) + + if xp == np: + return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) + elif xp == cp: + return solve_triangular_device(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) + else: + ModuleNotFoundError("Unknown Module") + + + +def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=False, side=0): + """Solve the equation a x = b for x, assuming a is a triangular matrix. + + Args: + a (cupy.ndarray): The matrix with dimension ``(M, M)``. + b (cupy.ndarray): The matrix with dimension ``(M,)`` or + ``(M, N)``. + lower (bool): Use only data contained in the lower triangle of ``a``. + Default is to use upper triangle. + trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: + + - *'0'* or *'N'* -- :math:`a x = b` + - *'1'* or *'T'* -- :math:`a^T x = b` + - *'2'* or *'C'* -- :math:`a^H x = b` + + unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are + assumed to be 1 and will not be referenced. + overwrite_b (bool): Allow overwriting data in b (may enhance + performance) + check_finite (bool): Whether to check that the input matrices contain + only finite numbers. Disabling may give a performance gain, but may + result in problems (crashes, non-termination) if the inputs do + contain infinities or NaNs. + + Returns: + cupy.ndarray: + The matrix with dimension ``(M,)`` or ``(M, N)``. + + .. seealso:: :func:`scipy.linalg.solve_triangular` + """ + + _util._assert_cupy_array(a, b) + + if len(a.shape) != 2 or a.shape[0] != a.shape[1]: + raise ValueError('expected square matrix') + if len(a) != len(b): + raise ValueError('incompatible dimensions') + + # Cast to float32 or float64 + if a.dtype.char in 'fd': + dtype = a.dtype + else: + dtype = np.promote_types(a.dtype.char, 'f') + + a = cp.array(a, dtype=dtype, order='F', copy=False) + b = cp.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) + + if check_finite: + if a.dtype.kind == 'f' and not cp.isfinite(a).all(): + raise ValueError( + 'array must not contain infs or NaNs') + if b.dtype.kind == 'f' and not cp.isfinite(b).all(): + raise ValueError( + 'array must not contain infs or NaNs') + + m, n = (b.size, 1) if b.ndim == 1 else b.shape + cublas_handle = device.get_cublas_handle() + + if dtype == 'f': + trsm = cublas.strsm + elif dtype == 'd': + trsm = cublas.dtrsm + elif dtype == 'F': + trsm = cublas.ctrsm + else: # dtype == 'D' + trsm = cublas.ztrsm + one = np.array(1, dtype=dtype) + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + if trans == 'N': + trans = cublas.CUBLAS_OP_N + elif trans == 'T': + trans = cublas.CUBLAS_OP_T + elif trans == 'C': + trans = cublas.CUBLAS_OP_C + + if unit_diagonal: + diag = cublas.CUBLAS_DIAG_UNIT + else: + diag = cublas.CUBLAS_DIAG_NON_UNIT + + if side: + blas_side = cublas.CUBLAS_SIDE_RIGHT + else: + blas_side = cublas.CUBLAS_SIDE_LEFT + + trsm( + cublas_handle, blas_side, uplo, + trans, diag, + m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) + + +def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=True, side=0): + """ + Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. + + Parameters + ---------- + a : (M, M) array_like + A triangular matrix + b : (M,) or (M, N) array_like + Right-hand side matrix in ``a x = b`` + lower : bool, optional + Use only data contained in the lower triangle of `a`. + Default is to use upper triangle. + trans : {0, 1, 2, 'N', 'T', 'C'}, optional + Type of system to solve: + + ======== ========= + trans system + ======== ========= + 0 or 'N' a x = b + 1 or 'T' a^T x = b + 2 or 'C' a^H x = b + ======== ========= + unit_diagonal : bool, optional + If True, diagonal elements of `a` are assumed to be 1 and + will not be referenced. + overwrite_b : bool, optional + Allow overwriting data in `b` (may enhance performance) + check_finite : bool, optional + Whether to check that the input matrices contain only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Returns + ------- + x : (M,) or (M, N) ndarray + Solution to the system ``a x = b``. Shape of return matches `b`. + + Raises + ------ + LinAlgError + If `a` is singular + + Notes + ----- + .. versionadded:: 0.9.0 + + Examples + -------- + Solve the lower triangular system a x = b, where:: + + [3 0 0 0] [4] + a = [2 1 0 0] b = [2] + [1 0 1 0] [4] + [1 1 1 1] [2] + + >>> import numpy as np + >>> from scipy.linalg import solve_triangular + >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) + >>> b = np.array([4, 2, 4, 2]) + >>> x = solve_triangular(a, b, lower=True) + >>> x + array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) + >>> a.dot(x) # Check the result + array([ 4., 2., 4., 2.]) + + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + b1 = _asarray_validated(b, check_finite=check_finite) + + if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: + raise ValueError('expected square matrix') + + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + # accommodate empty arrays + if b1.size == 0: + dt_nonempty = solve_triangular_host( + np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) + ).dtype + return np.empty_like(b1, dtype=dt_nonempty) + + overwrite_b = overwrite_b or _datacopied(b1, b) + + x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + return x + + +# solve_triangular without the input validation +def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + trsm, = get_blas_funcs(('trsm',), (a1, b1)) + + if a1.dtype.char in 'fd': + dtype = a1.dtype + else: + dtype = np.promote_types(a1.dtype.char, 'f') + + one = np.array(1, dtype=dtype) + alpha = one.ctypes.data + + if a1.flags.f_contiguous or trans == 2: + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, + trans_a=trans, diag=unit_diagonal, side=side) + else: + # transposed system is solved since trtrs expects Fortran ordering + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + trans_a=not trans, diag=unit_diagonal, side=side) + + return x \ No newline at end of file diff --git a/src/serinv/utils/trsm_solve_device.py b/src/serinv/utils/trsm_solve_device.py deleted file mode 100644 index 9d6bdb43..00000000 --- a/src/serinv/utils/trsm_solve_device.py +++ /dev/null @@ -1,102 +0,0 @@ -import numpy - -from cupy.cuda import cublas -from cupy.cuda import device -from cupy.linalg import _util - - -def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, check_finite=False, aplha = 1., side=0): - """Solve the equation a x = b for x, assuming a is a triangular matrix. - - Args: - a (cupy.ndarray): The matrix with dimension ``(M, M)``. - b (cupy.ndarray): The matrix with dimension ``(M,)`` or - ``(M, N)``. - lower (bool): Use only data contained in the lower triangle of ``a``. - Default is to use upper triangle. - trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: - - - *'0'* or *'N'* -- :math:`a x = b` - - *'1'* or *'T'* -- :math:`a^T x = b` - - *'2'* or *'C'* -- :math:`a^H x = b` - - unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are - assumed to be 1 and will not be referenced. - overwrite_b (bool): Allow overwriting data in b (may enhance - performance) - check_finite (bool): Whether to check that the input matrices contain - only finite numbers. Disabling may give a performance gain, but may - result in problems (crashes, non-termination) if the inputs do - contain infinities or NaNs. - - Returns: - cupy.ndarray: - The matrix with dimension ``(M,)`` or ``(M, N)``. - - .. seealso:: :func:`scipy.linalg.solve_triangular` - """ - - _util._assert_cupy_array(a, b) - - if len(a.shape) != 2 or a.shape[0] != a.shape[1]: - raise ValueError('expected square matrix') - if len(a) != len(b): - raise ValueError('incompatible dimensions') - - # Cast to float32 or float64 - if a.dtype.char in 'fd': - dtype = a.dtype - else: - dtype = numpy.promote_types(a.dtype.char, 'f') - - a = cupy.array(a, dtype=dtype, order='F', copy=False) - b = cupy.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) - - if check_finite: - if a.dtype.kind == 'f' and not cupy.isfinite(a).all(): - raise ValueError( - 'array must not contain infs or NaNs') - if b.dtype.kind == 'f' and not cupy.isfinite(b).all(): - raise ValueError( - 'array must not contain infs or NaNs') - - m, n = (b.size, 1) if b.ndim == 1 else b.shape - cublas_handle = device.get_cublas_handle() - - if dtype == 'f': - trsm = cublas.strsm - elif dtype == 'd': - trsm = cublas.dtrsm - elif dtype == 'F': - trsm = cublas.ctrsm - else: # dtype == 'D' - trsm = cublas.ztrsm - one = numpy.array(1, dtype=dtype) - - if lower: - uplo = cublas.CUBLAS_FILL_MODE_LOWER - else: - uplo = cublas.CUBLAS_FILL_MODE_UPPER - - if trans == 'N': - trans = cublas.CUBLAS_OP_N - elif trans == 'T': - trans = cublas.CUBLAS_OP_T - elif trans == 'C': - trans = cublas.CUBLAS_OP_C - - if unit_diagonal: - diag = cublas.CUBLAS_DIAG_UNIT - else: - diag = cublas.CUBLAS_DIAG_NON_UNIT - - if side: - blas_side = cublas.CUBLAS_SIDE_RIGHT - else: - blas_side = cublas.CUBLAS_SIDE_LEFT - - trsm( - cublas_handle, blas_side, uplo, - trans, diag, - m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) \ No newline at end of file From 08a768702b132cbc87bdfdb18796a427bbb5a348 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:24:41 +0000 Subject: [PATCH 007/157] used serinv_matmul once for testing --- src/serinv/algs/pobtas.py | 6 ++++-- src/serinv/utils/__init__.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bab2a911..13fa4628 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -4,8 +4,10 @@ from serinv import ( ArrayLike, _get_module_from_array, + ) +from serinv.utils import serinv_matmul, serinv_solve_triangular def pobtas( L_diagonal_blocks: ArrayLike, @@ -86,9 +88,9 @@ def _pobtas( lower=True, ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + , B[i * diag_blocksize : (i + 1) * diag_blocksize] ) B[-arrow_blocksize:] -= ( diff --git a/src/serinv/utils/__init__.py b/src/serinv/utils/__init__.py index 4079f81c..00a5a351 100644 --- a/src/serinv/utils/__init__.py +++ b/src/serinv/utils/__init__.py @@ -8,6 +8,9 @@ from serinv.utils.pobtx import allocate_pobtx_permutation_buffers from serinv.utils.pobtax import allocate_pobtax_permutation_buffers +from serinv.utils.matmul import serinv_matmul +from serinv.utils.trsm import serinv_solve_triangular + __all__ = [ "check_block_dd", "check_ddbta", @@ -15,4 +18,6 @@ "allocate_ddbtax_permutation_buffers", "allocate_pobtx_permutation_buffers", "allocate_pobtax_permutation_buffers", + "serinv_matmul", + "serinv_solve_triangluar" ] From dfef1722d5724c748500fb65a17877104b84afd8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:26:52 +0000 Subject: [PATCH 008/157] used serinv_solve_triangular once for testing --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 13fa4628..72a4c751 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -82,10 +82,10 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = serinv_solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + lower=True, side=0 ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( From 0b9ceb7998929dc24ed01429fb6b5e5f33cc9039 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:41:41 +0000 Subject: [PATCH 009/157] updated solve_triangular_deive to newer cupy implementation --- src/serinv/utils/trsm.py | 142 ++++++++++++++++++++++++++++++--------- 1 file changed, 111 insertions(+), 31 deletions(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 159f32a4..a2fac8d5 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -66,10 +66,23 @@ def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, _util._assert_cupy_array(a, b) - if len(a.shape) != 2 or a.shape[0] != a.shape[1]: - raise ValueError('expected square matrix') - if len(a) != len(b): - raise ValueError('incompatible dimensions') + if a.ndim == 2: + if a.shape[0] != a.shape[1]: + raise ValueError('expected square matrix') + if len(a) != len(b): + raise ValueError('incompatible dimensions') + batch_count = 0 + elif a.ndim > 2: + if a.shape[-1] != a.shape[-2]: + raise ValueError('expected a batch of square matrices') + if a.shape[:-2] != b.shape[:a.ndim - 2]: + raise ValueError('incompatible batch count') + if b.ndim < a.ndim - 1 or a.shape[-2] != b.shape[a.ndim - 2]: + raise ValueError('incompatible dimensions') + batch_count = math.prod(a.shape[:-2]) + else: + raise ValueError( + 'expected one square matrix or a batch of square matrices') # Cast to float32 or float64 if a.dtype.char in 'fd': @@ -77,9 +90,6 @@ def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, else: dtype = np.promote_types(a.dtype.char, 'f') - a = cp.array(a, dtype=dtype, order='F', copy=False) - b = cp.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) - if check_finite: if a.dtype.kind == 'f' and not cp.isfinite(a).all(): raise ValueError( @@ -88,17 +98,69 @@ def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, raise ValueError( 'array must not contain infs or NaNs') - m, n = (b.size, 1) if b.ndim == 1 else b.shape - cublas_handle = device.get_cublas_handle() + if batch_count: + m, n = b.shape[-2:] if b.ndim == a.ndim else (b.shape[-1], 1) + + a_new_shape = (batch_count, m, m) + b_shape = b.shape + b_data_ptr = b.data.ptr + # trsm receives Fortran array, but we want zero copy + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + # normal Fortran upper == transpose C lower + trans = cublas.CUBLAS_OP_T + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), dtype=dtype) + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + # transpose Fortran upper == normal C lower + trans = cublas.CUBLAS_OP_N + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), dtype=dtype) + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: + if dtype == 'f' or dtype == 'd': + # real numbers + # Hermitian Fortran upper == transpose Fortran upper + # == normal C lower + trans = cublas.CUBLAS_OP_N + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), + dtype=dtype) + else: + # complex numbers + trans = cublas.CUBLAS_OP_C + a = cp.ascontiguousarray( + a.reshape(*a_new_shape).transpose(0, 2, 1), dtype=dtype) + else: # know nothing about `trans`, just convert C to Fortran + a = cp.ascontiguousarray( + a.reshape(*a_new_shape).transpose(0, 2, 1), dtype=dtype) + b = cp.ascontiguousarray( + b.reshape(batch_count, m, n).transpose(0, 2, 1), dtype=dtype) + if b.data.ptr == b_data_ptr and not overwrite_b: + b = b.copy() + + start = a.data.ptr + step = m * m * a.itemsize + stop = start + step * batch_count + a_array = cp.arange(start, stop, step, dtype=cp.uintp) + + start = b.data.ptr + step = m * n * b.itemsize + stop = start + step * batch_count + b_array = cp.arange(start, stop, step, dtype=cp.uintp) + else: + a = cp.array(a, dtype=dtype, order='F', copy=None) + b = cp.array(b, dtype=dtype, order='F', + copy=(None if overwrite_b else True)) + + m, n = (b.size, 1) if b.ndim == 1 else b.shape - if dtype == 'f': - trsm = cublas.strsm - elif dtype == 'd': - trsm = cublas.dtrsm - elif dtype == 'F': - trsm = cublas.ctrsm - else: # dtype == 'D' - trsm = cublas.ztrsm + if trans == 'N': + trans = cublas.CUBLAS_OP_N + elif trans == 'T': + trans = cublas.CUBLAS_OP_T + elif trans == 'C': + trans = cublas.CUBLAS_OP_C + + cublas_handle = device.get_cublas_handle() one = np.array(1, dtype=dtype) if lower: @@ -106,27 +168,45 @@ def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, else: uplo = cublas.CUBLAS_FILL_MODE_UPPER - if trans == 'N': - trans = cublas.CUBLAS_OP_N - elif trans == 'T': - trans = cublas.CUBLAS_OP_T - elif trans == 'C': - trans = cublas.CUBLAS_OP_C - if unit_diagonal: diag = cublas.CUBLAS_DIAG_UNIT else: diag = cublas.CUBLAS_DIAG_NON_UNIT if side: - blas_side = cublas.CUBLAS_SIDE_RIGHT + side = cublas.CUBLAS_SIDE_RIGHT else: - blas_side = cublas.CUBLAS_SIDE_LEFT - - trsm( - cublas_handle, blas_side, uplo, - trans, diag, - m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) + side = cublas.CUBLAS_SIDE_LEFT + + if batch_count: + if dtype == 'f': + trsm = cublas.strsmBatched + elif dtype == 'd': + trsm = cublas.dtrsmBatched + elif dtype == 'F': + trsm = cublas.ctrsmBatched + else: # dtype == 'D' + trsm = cublas.ztrsmBatched + trsm( + cublas_handle, side, uplo, + trans, diag, + m, n, one.ctypes.data, a_array.data.ptr, m, + b_array.data.ptr, m, batch_count) + return b.transpose(0, 2, 1).reshape(b_shape) + else: + if dtype == 'f': + trsm = cublas.strsm + elif dtype == 'd': + trsm = cublas.dtrsm + elif dtype == 'F': + trsm = cublas.ctrsm + else: # dtype == 'D' + trsm = cublas.ztrsm + trsm( + cublas_handle, side, uplo, + trans, diag, + m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) + return b def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, From b073708a49e2937ccdf10b7d2ad93605ac8c2e65 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:42:51 +0000 Subject: [PATCH 010/157] reomved serinv_solve_triangular --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 72a4c751..67d21328 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -82,10 +82,10 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = serinv_solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, side=0 + lower=True ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( From 0ec3284bffbe75cbd0f09b5aa3fe0767f5290f28 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:45:41 +0000 Subject: [PATCH 011/157] debug messages --- src/serinv/algs/pobtas.py | 4 ++-- src/serinv/utils/trsm.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 67d21328..72a4c751 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -82,10 +82,10 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = serinv_solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True + lower=True, side=0 ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index a2fac8d5..c0da9e20 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -21,11 +21,14 @@ def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept plus the side parameter which can either be 0 or 1 for left or right hand side """ + print("one") xp = _get_module_from_array(a) - + print("two") if xp == np: + print("three") return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) elif xp == cp: + print("four") return solve_triangular_device(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) else: ModuleNotFoundError("Unknown Module") From 176577000454eaed0019d882934bdfd17eee0074 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:48:24 +0000 Subject: [PATCH 012/157] print xp --- src/serinv/utils/trsm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index c0da9e20..550540cd 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -23,6 +23,7 @@ def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, """ print("one") xp = _get_module_from_array(a) + print(xp) print("two") if xp == np: print("three") From 3c56e0900aac05146e91a326ce9b5577a5591555 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:49:21 +0000 Subject: [PATCH 013/157] fixed module tuple --- src/serinv/utils/trsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 550540cd..6190075b 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -22,7 +22,7 @@ def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, plus the side parameter which can either be 0 or 1 for left or right hand side """ print("one") - xp = _get_module_from_array(a) + xp, la = _get_module_from_array(a) print(xp) print("two") if xp == np: From bc1d67fa91e12d1c15dcc4f7e891610ec9b8be54 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:52:47 +0000 Subject: [PATCH 014/157] removed transpose --- src/serinv/utils/trsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 6190075b..2c218fd2 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -324,7 +324,7 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans_a=trans, diag=unit_diagonal, side=side) else: # transposed system is solved since trtrs expects Fortran ordering - x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=not lower, trans_a=not trans, diag=unit_diagonal, side=side) return x \ No newline at end of file From d756ab9d48aacbd4d7d4615c2f8d97c406325f91 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:59:10 +0000 Subject: [PATCH 015/157] print trsm func --- src/serinv/utils/trsm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 2c218fd2..334e31bd 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -310,6 +310,7 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) trsm, = get_blas_funcs(('trsm',), (a1, b1)) + print(trsm) if a1.dtype.char in 'fd': dtype = a1.dtype @@ -324,7 +325,7 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans_a=trans, diag=unit_diagonal, side=side) else: # transposed system is solved since trtrs expects Fortran ordering - x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=not lower, + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, trans_a=not trans, diag=unit_diagonal, side=side) return x \ No newline at end of file From cf79d8d8aa48e38293cf0f5f14ef644ab639bc46 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 09:11:45 +0000 Subject: [PATCH 016/157] changed print --- src/serinv/utils/trsm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 334e31bd..ae310acb 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -21,10 +21,10 @@ def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept plus the side parameter which can either be 0 or 1 for left or right hand side """ - print("one") + print(a) xp, la = _get_module_from_array(a) print(xp) - print("two") + print(b) if xp == np: print("three") return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) From 4554d592f2822aca14472564a6fa1681e92a8769 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 09:31:16 +0000 Subject: [PATCH 017/157] changed alpha to 1 --- src/serinv/utils/trsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index ae310acb..64ed8578 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -318,7 +318,7 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, dtype = np.promote_types(a1.dtype.char, 'f') one = np.array(1, dtype=dtype) - alpha = one.ctypes.data + alpha = 1 if a1.flags.f_contiguous or trans == 2: x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, From 5a84dea9619938e2598b8d3d976b8a98cba1a848 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 09:32:02 +0000 Subject: [PATCH 018/157] removed one --- src/serinv/utils/trsm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 64ed8578..c8214a9b 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -317,7 +317,6 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, else: dtype = np.promote_types(a1.dtype.char, 'f') - one = np.array(1, dtype=dtype) alpha = 1 if a1.flags.f_contiguous or trans == 2: From a7d4c183896c1018927b1594fc6e611b14c983dc Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 10:27:31 +0000 Subject: [PATCH 019/157] changed matmul host to own implementation --- src/serinv/utils/matmul.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index aa68ab31..7febd69e 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -18,7 +18,7 @@ def serinv_matmul (a, b): xp, la = _get_module_from_array(a) if xp == np: - return matmul(a, b) + return matmul_gemm_host(a, b) elif xp == cp: return gemm('N', 'N', a, b) else: @@ -125,11 +125,8 @@ def _solve_triangular(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): else: dtype = np.promote_types(a1.dtype.char, 'f') - one = np.array(1, dtype=dtype) - zero =np.array(0, dtype=dtype) - alpha = one.ctypes.data - beta = zero.ctypes.data - + alpha = 1 + beta = 0 x = gemm(alpha, a1.T, b1.T, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) From 3e92f8960d21a55ad22b4b738c4fc4a1ac44faa6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 10:28:32 +0000 Subject: [PATCH 020/157] changed array ordering --- src/serinv/utils/matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index 7febd69e..e67b4fdb 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -128,7 +128,7 @@ def _solve_triangular(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): alpha = 1 beta = 0 - x = gemm(alpha, a1.T, b1.T, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) return x From 372ea0ec405dbf9862398d71268e238638a09882 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 10:29:38 +0000 Subject: [PATCH 021/157] changed name of matmul function --- src/serinv/utils/matmul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index e67b4fdb..b8ccb241 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -109,12 +109,12 @@ def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=Fal ).dtype return np.empty_like(b1, dtype=dt_nonempty) - x = _solve_triangular(a1, b1, trans_a, trans_b, overwrite_c) + x = _matmul_gemm(a1, b1, trans_a, trans_b, overwrite_c) return x # solve_triangular without the input validation -def _solve_triangular(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): +def _matmul_gemm(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) From cc5138e448edab9e2ccc452fe6c1465527f6ee1e Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 10:42:11 +0000 Subject: [PATCH 022/157] expose trans param for matmul --- src/serinv/utils/matmul.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index b8ccb241..a2cc7d27 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -13,14 +13,14 @@ except (ImportError, ImportWarning, ModuleNotFoundError): pass -def serinv_matmul (a, b): +def serinv_matmul (a, b, trans_a = 'N', trans_b = 'N'): """Wrapper to call GeMM for host or device""" xp, la = _get_module_from_array(a) if xp == np: - return matmul_gemm_host(a, b) + return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) elif xp == cp: - return gemm('N', 'N', a, b) + return gemm(trans_a, trans_b, a, b) else: ModuleNotFoundError("Unknown Module") From 7c4f97bc02f1a00895c99a4a7e390ae79b56c88f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 07:43:36 +0000 Subject: [PATCH 023/157] removed local functions from pobtas to prepare for renaming and moving --- src/serinv/algs/pobtas.py | 6 +++--- src/serinv/utils/matmul.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 72a4c751..ec7ef158 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -82,15 +82,15 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = serinv_solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], lower=True, side=0 ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( L_lower_diagonal_blocks[i] - , B[i * diag_blocksize : (i + 1) * diag_blocksize] + @ B[i * diag_blocksize : (i + 1) * diag_blocksize] ) B[-arrow_blocksize:] -= ( diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index a2cc7d27..f57e401a 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -9,7 +9,7 @@ try: import cupy as cp - from cupy.cublas import gemm + from cupy.cublas import gemm as cp_gemm except (ImportError, ImportWarning, ModuleNotFoundError): pass @@ -20,7 +20,7 @@ def serinv_matmul (a, b, trans_a = 'N', trans_b = 'N'): if xp == np: return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) elif xp == cp: - return gemm(trans_a, trans_b, a, b) + return cp_gemm(trans_a, trans_b, a, b) else: ModuleNotFoundError("Unknown Module") From b56e59156a37ae91a84c79a40da13641b3b7ceb6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 07:54:34 +0000 Subject: [PATCH 024/157] changed solve_triangluar to trsm in pobtaf --- src/serinv/algs/pobtaf.py | 25 ++++++++++--------- src/serinv/algs/pobtas.py | 1 - src/serinv/block_primitive/__init__.py | 7 ++++++ .../matmul.py => block_primitive/gemm.py} | 2 +- src/serinv/{utils => block_primitive}/trsm.py | 2 +- src/serinv/utils/__init__.py | 5 +--- 6 files changed, 23 insertions(+), 19 deletions(-) create mode 100644 src/serinv/block_primitive/__init__.py rename src/serinv/{utils/matmul.py => block_primitive/gemm.py} (98%) rename src/serinv/{utils => block_primitive}/trsm.py (99%) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 05f58d3b..4a99ca19 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -7,6 +7,7 @@ _get_cholesky, ) +from serinv.block_primitive import trsm, gemm def pobtaf( A_diagonal_blocks: ArrayLike, @@ -118,7 +119,7 @@ def _pobtaf( # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :].conj().T, lower=True, @@ -129,7 +130,7 @@ def _pobtaf( # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, @@ -164,7 +165,7 @@ def _pobtaf( # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} L_lower_arrow_blocks[-1, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, lower=True, @@ -210,7 +211,7 @@ def _pobtaf_permuted( # Compute lower factors # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :].conj().T, lower=True, @@ -221,7 +222,7 @@ def _pobtaf_permuted( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} buffer[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, lower=True, @@ -232,7 +233,7 @@ def _pobtaf_permuted( # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, @@ -391,7 +392,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -418,7 +419,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -488,7 +489,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -654,7 +655,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -681,7 +682,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -701,7 +702,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index ec7ef158..f4b825d2 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -7,7 +7,6 @@ ) -from serinv.utils import serinv_matmul, serinv_solve_triangular def pobtas( L_diagonal_blocks: ArrayLike, diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py new file mode 100644 index 00000000..b8cf0541 --- /dev/null +++ b/src/serinv/block_primitive/__init__.py @@ -0,0 +1,7 @@ +from serinv.block_primitive.gemm import gemm +from serinv.block_primitive.trsm import trsm + +__all__ = [ + "gemm", + "trsm" +] \ No newline at end of file diff --git a/src/serinv/utils/matmul.py b/src/serinv/block_primitive/gemm.py similarity index 98% rename from src/serinv/utils/matmul.py rename to src/serinv/block_primitive/gemm.py index f57e401a..975a4b85 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/block_primitive/gemm.py @@ -13,7 +13,7 @@ except (ImportError, ImportWarning, ModuleNotFoundError): pass -def serinv_matmul (a, b, trans_a = 'N', trans_b = 'N'): +def gemm (a, b, trans_a = 'N', trans_b = 'N'): """Wrapper to call GeMM for host or device""" xp, la = _get_module_from_array(a) diff --git a/src/serinv/utils/trsm.py b/src/serinv/block_primitive/trsm.py similarity index 99% rename from src/serinv/utils/trsm.py rename to src/serinv/block_primitive/trsm.py index c8214a9b..1809289a 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -14,7 +14,7 @@ except (ImportError, ImportWarning, ModuleNotFoundError): pass -def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, +def trsm(a, b, trans=0, lower = False, unit_diagonal=False, overwrite_b=False, check_finite=False, side=0): """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device diff --git a/src/serinv/utils/__init__.py b/src/serinv/utils/__init__.py index 00a5a351..1c54d228 100644 --- a/src/serinv/utils/__init__.py +++ b/src/serinv/utils/__init__.py @@ -8,8 +8,7 @@ from serinv.utils.pobtx import allocate_pobtx_permutation_buffers from serinv.utils.pobtax import allocate_pobtax_permutation_buffers -from serinv.utils.matmul import serinv_matmul -from serinv.utils.trsm import serinv_solve_triangular + __all__ = [ "check_block_dd", @@ -18,6 +17,4 @@ "allocate_ddbtax_permutation_buffers", "allocate_pobtx_permutation_buffers", "allocate_pobtax_permutation_buffers", - "serinv_matmul", - "serinv_solve_triangluar" ] From bdea80fab4b23c4e5d845fecaab7fdf584f5a236 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 07:56:01 +0000 Subject: [PATCH 025/157] =?UTF-8?q?removed=20l=C3=B6eftover=20side=20param?= =?UTF-8?q?eter=20from=20pobtas?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f4b825d2..dbc2d916 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -84,7 +84,7 @@ def _pobtas( B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, side=0 + lower=True ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( From 3bd5108f786b8853448f17bb1672c8677e07cd2f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 07:57:04 +0000 Subject: [PATCH 026/157] removed double conjugate once for testing side --- src/serinv/algs/pobtaf.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 4a99ca19..7cbf38ff 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -121,11 +121,9 @@ def _pobtaf( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + lower=True, side = 1 ) - .conj() - .T ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} From 5f25de7bb03423374d1da0f68c52770555fcd7d6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:02:02 +0000 Subject: [PATCH 027/157] test tom check schape and size of trsm with side --- src/serinv/algs/pobtaf.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7cbf38ff..8af9708e 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -117,13 +117,30 @@ def _pobtaf( # L_{i, i} = chol(A_{i, i}) L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :]) + + ### + # Testing for shape and size + print(L_diagonal_blocks[i, :, :]) + print(A_lower_diagonal_blocks[i, :, :].conj().T) + print(A_lower_diagonal_blocks[i, :, :]) + + L_test = trsm( + L_diagonal_blocks[i, :, :], + A_lower_diagonal_blocks[i, :, :], + lower=True, side = 1 + ) + print(L_test) + ### + # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :], - lower=True, side = 1 + A_lower_diagonal_blocks[i, :, :].conj().T, + lower=True, ) + .conj() + .T ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} From 332e1ebbe3e0100c698aea1734bfe63902a4839e Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:06:12 +0000 Subject: [PATCH 028/157] forced error to test shapes --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 8af9708e..d458ade7 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -137,7 +137,7 @@ def _pobtaf( trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + lower=True, side = 1 ) .conj() .T From d5c59a41a0745144b1d673b88e13c336058031a4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:07:28 +0000 Subject: [PATCH 029/157] removed some debug messages --- src/serinv/block_primitive/trsm.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py index 1809289a..0ca1e0ca 100644 --- a/src/serinv/block_primitive/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -21,15 +21,10 @@ def trsm(a, b, trans=0, lower = False, unit_diagonal=False, For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept plus the side parameter which can either be 0 or 1 for left or right hand side """ - print(a) xp, la = _get_module_from_array(a) - print(xp) - print(b) if xp == np: - print("three") return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) elif xp == cp: - print("four") return solve_triangular_device(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) else: ModuleNotFoundError("Unknown Module") From 24a7bb0529ff76f26d49fd3f771bcb517fa1829f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:12:00 +0000 Subject: [PATCH 030/157] more testing --- src/serinv/algs/pobtaf.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d458ade7..cbcde138 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -120,16 +120,30 @@ def _pobtaf( ### # Testing for shape and size + print("###") print(L_diagonal_blocks[i, :, :]) + print("###") print(A_lower_diagonal_blocks[i, :, :].conj().T) + print("###") print(A_lower_diagonal_blocks[i, :, :]) - + print("###") L_test = trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], lower=True, side = 1 ) print(L_test) + print("###") + L_test = ( + trsm( + L_diagonal_blocks[i, :, :], + A_lower_diagonal_blocks[i, :, :].conj().T, + lower=True, side = 0 + ) + .conj() + .T + ) + print(L_test) ### # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} From 11315008611041e481c91596569e8ceb6f0f0074 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:13:13 +0000 Subject: [PATCH 031/157] more debug --- src/serinv/algs/pobtaf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index cbcde138..d5eadefb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -127,6 +127,7 @@ def _pobtaf( print("###") print(A_lower_diagonal_blocks[i, :, :]) print("###") + print("side = 1 sol") L_test = trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], @@ -134,6 +135,7 @@ def _pobtaf( ) print(L_test) print("###") + print("side = 0 sol") L_test = ( trsm( L_diagonal_blocks[i, :, :], From fda845619c3041cb90261c60512e194daa26d303 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:14:18 +0000 Subject: [PATCH 032/157] swapped A and B in test --- src/serinv/algs/pobtaf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d5eadefb..ee090233 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,8 +129,9 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], + L_diagonal_blocks[i, :, :], + lower=True, side = 1 ) print(L_test) From e1b57756f61d9e5a3c7dfb3478bf290c7c747a43 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:17:27 +0000 Subject: [PATCH 033/157] swapped back --- src/serinv/algs/pobtaf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index ee090233..d5eadefb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,9 +129,8 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - A_lower_diagonal_blocks[i, :, :], L_diagonal_blocks[i, :, :], - + A_lower_diagonal_blocks[i, :, :], lower=True, side = 1 ) print(L_test) From df626e3718302144590d59d46de46824554a9fc8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:28:49 +0000 Subject: [PATCH 034/157] transpose L --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d5eadefb..aaa2da25 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,7 +129,7 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - L_diagonal_blocks[i, :, :], + L_diagonal_blocks[i, :, :].conj.T, A_lower_diagonal_blocks[i, :, :], lower=True, side = 1 ) From 0d60090cf487c32ac4ef531b92074ab64f795b54 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:29:37 +0000 Subject: [PATCH 035/157] typo --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index aaa2da25..6ebb6cbb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,7 +129,7 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - L_diagonal_blocks[i, :, :].conj.T, + L_diagonal_blocks[i, :, :].conj().T, A_lower_diagonal_blocks[i, :, :], lower=True, side = 1 ) From 142acdc89d332a76b2cc9b257054a7902869fcc4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:31:28 +0000 Subject: [PATCH 036/157] changed conj.T to the trans param --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 6ebb6cbb..7c004630 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,9 +129,9 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - L_diagonal_blocks[i, :, :].conj().T, + L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], - lower=True, side = 1 + trans=2,lower=True, side = 1 ) print(L_test) print("###") From bfbdcf1d1df880be943628d4cb218ee21e4e6e52 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:32:49 +0000 Subject: [PATCH 037/157] actually implement frist trsm --- src/serinv/algs/pobtaf.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7c004630..68d196ee 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -131,7 +131,7 @@ def _pobtaf( L_test = trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], - trans=2,lower=True, side = 1 + trans='C',lower=True, side=1 ) print(L_test) print("###") @@ -152,11 +152,10 @@ def _pobtaf( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, side = 1 + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T + ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} From e80022c959e314edfe7f18569e5fc412839d5396 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:36:49 +0000 Subject: [PATCH 038/157] changed all trsm's --- src/serinv/algs/pobtaf.py | 103 +++++++++----------------------------- 1 file changed, 25 insertions(+), 78 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 68d196ee..0577b66c 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -117,37 +117,6 @@ def _pobtaf( # L_{i, i} = chol(A_{i, i}) L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :]) - - ### - # Testing for shape and size - print("###") - print(L_diagonal_blocks[i, :, :]) - print("###") - print(A_lower_diagonal_blocks[i, :, :].conj().T) - print("###") - print(A_lower_diagonal_blocks[i, :, :]) - print("###") - print("side = 1 sol") - L_test = trsm( - L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :], - trans='C',lower=True, side=1 - ) - print(L_test) - print("###") - print("side = 0 sol") - L_test = ( - trsm( - L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, side = 0 - ) - .conj() - .T - ) - print(L_test) - ### - # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( @@ -162,11 +131,9 @@ def _pobtaf( L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :].conj().T, - lower=True, + A_lower_arrow_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # Update next diagonal block @@ -197,11 +164,9 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], - A_lower_arrow_blocks[-1, :, :].conj().T, - lower=True, + A_lower_arrow_blocks[-1, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} @@ -243,11 +208,9 @@ def _pobtaf_permuted( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} @@ -255,21 +218,17 @@ def _pobtaf_permuted( trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, - lower=True, + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :].conj().T, - lower=True, + A_lower_arrow_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # Update next diagonal block @@ -422,13 +381,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -449,13 +406,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_arrow_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_arrow_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_arrow_events[i % 2].record(stream=compute_stream) @@ -519,13 +474,11 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], - A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, - lower=True, + A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_arrow_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) @@ -685,13 +638,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -712,13 +663,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, - lower=True, + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_arrow_events[i % 2].record(stream=compute_stream) @@ -732,13 +681,11 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, - lower=True, + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_upper_nested_dissection_buffer_events[i % 2].record( stream=compute_stream From 2a8291051c701dee4115a616a57229148d8da451 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:40:36 +0000 Subject: [PATCH 039/157] after the previous version faailed, this is the second attempt --- src/serinv/algs/pobtaf.py | 66 +++++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0577b66c..5184a19d 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -164,9 +164,11 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], - A_lower_arrow_blocks[-1, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks[-1, :, :].conj().T, + lower=True, ) + .conj() + .T ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} @@ -208,9 +210,11 @@ def _pobtaf_permuted( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks[i, :, :].conj().T, + lower=True, ) + .conj() + .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} @@ -218,17 +222,21 @@ def _pobtaf_permuted( trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, - trans='C',lower=True, side=1 + lower=True, ) + .conj() + .T ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks[i, :, :].conj().T, + lower=True, ) + .conj() + .T ) # Update next diagonal block @@ -381,11 +389,13 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -406,11 +416,13 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_arrow_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) compute_arrow_events[i % 2].record(stream=compute_stream) @@ -474,11 +486,13 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], - A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) compute_arrow_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) @@ -638,11 +652,13 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -663,11 +679,13 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, - trans='C',lower=True, side=1 + lower=True, ) + .conj() + .T ) cp_arrow_events[i % 2].record(stream=compute_stream) @@ -681,11 +699,13 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, - trans='C',lower=True, side=1 + lower=True, ) + .conj() + .T ) cp_upper_nested_dissection_buffer_events[i % 2].record( stream=compute_stream From df87182a09a9a15bcacdc9096e9bac14c6370a34 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:49:27 +0000 Subject: [PATCH 040/157] removed side from arrow because of dim mismatch, added it to other arrow to check for dim mismatch --- src/serinv/algs/pobtaf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 5184a19d..c2c01698 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -131,9 +131,11 @@ def _pobtaf( L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks[i, :, :].conj().T, + lower=True, ) + .conj() + .T ) # Update next diagonal block @@ -164,11 +166,9 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], - A_lower_arrow_blocks[-1, :, :].conj().T, - lower=True, + A_lower_arrow_blocks[-1, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} From db39806384a2d72db626cf88782d7091dfcada73 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:52:03 +0000 Subject: [PATCH 041/157] implemented trsm side right for all non arrow solves --- src/serinv/algs/pobtaf.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index c2c01698..a85a6a03 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -166,9 +166,11 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], - A_lower_arrow_blocks[-1, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks[-1, :, :].conj().T, + lower=True, ) + .conj() + .T ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} @@ -210,11 +212,9 @@ def _pobtaf_permuted( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} @@ -389,13 +389,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -416,7 +414,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -486,7 +484,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -652,13 +650,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -679,7 +675,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -699,7 +695,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, From 91b6d9bcb2dddf575269bcfb15a17a64e6518a1a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:21:36 +0000 Subject: [PATCH 042/157] imported cupy gemm to local --- src/serinv/block_primitive/gemm.py | 157 +++++++++++++++++++++++++++-- 1 file changed, 151 insertions(+), 6 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 975a4b85..78189ede 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -10,6 +10,9 @@ try: import cupy as cp from cupy.cublas import gemm as cp_gemm + from cupy_backends.cuda.libs import cublas + from cupy import _core + from cupy.cuda import device except (ImportError, ImportWarning, ModuleNotFoundError): pass @@ -20,12 +23,12 @@ def gemm (a, b, trans_a = 'N', trans_b = 'N'): if xp == np: return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) elif xp == cp: - return cp_gemm(trans_a, trans_b, a, b) + return matmul_gemm_device(trans_a, trans_b, a, b) else: ModuleNotFoundError("Unknown Module") -def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): +def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): """ Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. @@ -95,12 +98,16 @@ def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=Fal a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) + c1 = _asarray_validated(c, check_finite=check_finite) if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: raise ValueError('expected square matrix') if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + if beta != 0 and c1 == None: + raise ValueError('expected C matrix') # accommodate empty arrays if b1.size == 0: @@ -109,12 +116,12 @@ def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=Fal ).dtype return np.empty_like(b1, dtype=dt_nonempty) - x = _matmul_gemm(a1, b1, trans_a, trans_b, overwrite_c) + x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x # solve_triangular without the input validation -def _matmul_gemm(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): +def _matmul_gemm(a1, b1, alpha=1, beta=0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) @@ -127,8 +134,146 @@ def _matmul_gemm(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): alpha = 1 beta = 0 - - x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + if beta == 0: + x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + else: + x = gemm(alpha, a1, b1, beta, c1, trans_a, trans_b, overwrite_c) return x + + + +def _trans_to_cublas_op(trans): + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + trans = cublas.CUBLAS_OP_N + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + trans = cublas.CUBLAS_OP_T + elif trans == 'H' or trans == cublas.CUBLAS_OP_C: + trans = cublas.CUBLAS_OP_C + else: + raise TypeError('invalid trans (actual: {})'.format(trans)) + return trans + +def _decide_ld_and_trans(a, trans): + ld = None + if trans in (cublas.CUBLAS_OP_N, cublas.CUBLAS_OP_T): + if a._f_contiguous: + ld = a.shape[0] + elif a._c_contiguous: + ld = a.shape[1] + trans = 1 - trans + return ld, trans + + +def _change_order_if_necessary(a, lda): + if lda is None: + lda = a.shape[0] + if not a._f_contiguous: + a = a.copy(order='F') + return a, lda + +def _get_scalar_ptr(a, dtype): + if isinstance(a, cp.ndarray): + if a.dtype != dtype: + a = cp.array(a, dtype=dtype) + a_ptr = a.data.ptr + else: + if not (isinstance(a, np.ndarray) and a.dtype == dtype): + a = np.array(a, dtype=dtype) + a_ptr = a.ctypes.data + return a, a_ptr + + +def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): + """Computes out = alpha * op(a) @ op(b) + beta * out + + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'H'. + op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', + op(b) = b.T.conj() if transb is 'H'. + """ + assert a.ndim == b.ndim == 2 + assert a.dtype == b.dtype + dtype = a.dtype.char + if dtype == 'f': + func = cublas.sgemm + elif dtype == 'd': + func = cublas.dgemm + elif dtype == 'F': + func = cublas.cgemm + elif dtype == 'D': + func = cublas.zgemm + else: + raise TypeError('invalid dtype') + + + transa = _trans_to_cublas_op(transa) + transb = _trans_to_cublas_op(transb) + if transa == cublas.CUBLAS_OP_N: + m, k = a.shape + else: + k, m = a.shape + if transb == cublas.CUBLAS_OP_N: + n = b.shape[1] + assert b.shape[0] == k + else: + n = b.shape[0] + assert b.shape[1] == k + if out is None: + out = cp.empty((m, n), dtype=dtype, order='F') + beta = 0.0 + else: + assert out.ndim == 2 + assert out.shape == (m, n) + assert out.dtype == dtype + + alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype) + beta, beta_ptr = _get_scalar_ptr(beta, a.dtype) + handle = device.get_cublas_handle() + orig_mode = cublas.getPointerMode(handle) + if isinstance(alpha, cp.ndarray) or isinstance(beta, cp.ndarray): + if not isinstance(alpha, cp.ndarray): + alpha = cp.array(alpha) + alpha_ptr = alpha.data.ptr + if not isinstance(beta, cp.ndarray): + beta = cp.array(beta) + beta_ptr = beta.data.ptr + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_DEVICE) + else: + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_HOST) + + lda, transa = _decide_ld_and_trans(a, transa) + ldb, transb = _decide_ld_and_trans(b, transb) + if not (lda is None or ldb is None): + if out._f_contiguous: + try: + func(handle, transa, transb, m, n, k, alpha_ptr, + a.data.ptr, lda, b.data.ptr, ldb, beta_ptr, out.data.ptr, + m) + finally: + cublas.setPointerMode(handle, orig_mode) + return out + elif out._c_contiguous: + # Computes out.T = alpha * b.T @ a.T + beta * out.T + try: + func(handle, 1 - transb, 1 - transa, n, m, k, alpha_ptr, + b.data.ptr, ldb, a.data.ptr, lda, beta_ptr, out.data.ptr, + n) + finally: + cublas.setPointerMode(handle, orig_mode) + return out + + a, lda = _change_order_if_necessary(a, lda) + b, ldb = _change_order_if_necessary(b, ldb) + c = out + if not out._f_contiguous: + c = out.copy(order='F') + try: + func(handle, transa, transb, m, n, k, alpha_ptr, a.data.ptr, lda, + b.data.ptr, ldb, beta_ptr, c.data.ptr, m) + finally: + cublas.setPointerMode(handle, orig_mode) + if not out._f_contiguous: + _core.elementwise_copy(c, out) + return out \ No newline at end of file From d8c1d7a5a7ca6a8e74f334b82c3eb6b1f079b3fc Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:22:42 +0000 Subject: [PATCH 043/157] added error to test --- src/serinv/block_primitive/gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 78189ede..1f270211 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -193,6 +193,7 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', op(b) = b.T.conj() if transb is 'H'. """ + ValueError("TEST") assert a.ndim == b.ndim == 2 assert a.dtype == b.dtype dtype = a.dtype.char From 87302016a2a95a8166bcd748fdd59dc5cfc1e298 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:23:07 +0000 Subject: [PATCH 044/157] fixed error --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 1f270211..ef3c39c1 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -193,7 +193,7 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', op(b) = b.T.conj() if transb is 'H'. """ - ValueError("TEST") + raise ValueError("TEST") assert a.ndim == b.ndim == 2 assert a.dtype == b.dtype dtype = a.dtype.char From e89414630ce0e20300c67bfbec32e46b0ce2d31a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:25:11 +0000 Subject: [PATCH 045/157] implemented one provsionary gemm in pobtaf --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index a85a6a03..267236f9 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,8 +142,8 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + - gemm(L_lower_diagonal_blocks[i, :, :] + , L_lower_diagonal_blocks[i, :, :].conj().T) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From d2c3cd403a427431c507aa68a3847a22eb713941 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:26:56 +0000 Subject: [PATCH 046/157] removed error for testing --- src/serinv/block_primitive/gemm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index ef3c39c1..78189ede 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -193,7 +193,6 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', op(b) = b.T.conj() if transb is 'H'. """ - raise ValueError("TEST") assert a.ndim == b.ndim == 2 assert a.dtype == b.dtype dtype = a.dtype.char From d251adb965cc4600e4152f772a375674878dd9b3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:28:13 +0000 Subject: [PATCH 047/157] fixed validating array if no array was present to begin with --- src/serinv/block_primitive/gemm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 78189ede..4e332686 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -98,7 +98,8 @@ def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwr a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) - c1 = _asarray_validated(c, check_finite=check_finite) + if c != None: + c1 = _asarray_validated(c, check_finite=check_finite) if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: raise ValueError('expected square matrix') From 58429da259eb1b7f8ccdc091ff0ef188736ad740 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:28:57 +0000 Subject: [PATCH 048/157] fixed c1 one not existing if c was none --- src/serinv/block_primitive/gemm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 4e332686..c7c331bb 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -100,6 +100,8 @@ def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwr b1 = _asarray_validated(b, check_finite=check_finite) if c != None: c1 = _asarray_validated(c, check_finite=check_finite) + else: + c1 = None if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: raise ValueError('expected square matrix') From af86da81b9da30d29196a37bbd029f5a168da747 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:30:57 +0000 Subject: [PATCH 049/157] changed gemm to accomodate in place operations --- src/serinv/algs/pobtaf.py | 7 ++++--- src/serinv/block_primitive/gemm.py | 12 ++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 267236f9..bf8a129a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,9 +141,10 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - gemm(L_lower_diagonal_blocks[i, :, :] - , L_lower_diagonal_blocks[i, :, :].conj().T) + + gemm(L_lower_diagonal_blocks[i, :, :] + , L_lower_diagonal_blocks[i, :, :].conj().T, + A_diagonal_blocks[i + 1, :, :], -1.0, 1.0) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index c7c331bb..6d290e81 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -9,26 +9,25 @@ try: import cupy as cp - from cupy.cublas import gemm as cp_gemm from cupy_backends.cuda.libs import cublas from cupy import _core from cupy.cuda import device except (ImportError, ImportWarning, ModuleNotFoundError): pass -def gemm (a, b, trans_a = 'N', trans_b = 'N'): +def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): """Wrapper to call GeMM for host or device""" xp, la = _get_module_from_array(a) if xp == np: return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) elif xp == cp: - return matmul_gemm_device(trans_a, trans_b, a, b) + return matmul_gemm_device(trans_a, trans_b, a, b, c, alpha, beta) else: ModuleNotFoundError("Unknown Module") -def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): +def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): """ Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. @@ -124,7 +123,7 @@ def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwr # solve_triangular without the input validation -def _matmul_gemm(a1, b1, alpha=1, beta=0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): +def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) @@ -146,7 +145,7 @@ def _matmul_gemm(a1, b1, alpha=1, beta=0, c1=None, trans_a=0, trans_b=0, overwri return x - +# Util functions for cupy gemm def _trans_to_cublas_op(trans): if trans == 'N' or trans == cublas.CUBLAS_OP_N: trans = cublas.CUBLAS_OP_N @@ -186,6 +185,7 @@ def _get_scalar_ptr(a, dtype): a = np.array(a, dtype=dtype) a_ptr = a.ctypes.data return a, a_ptr +# Util functions for cupy gemm end def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): From 575962169bdc0fcbf7e19fa623db7166eaffbcde Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:33:57 +0000 Subject: [PATCH 050/157] changed first gemm to trans_b = c --- src/serinv/algs/pobtaf.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index bf8a129a..993da40b 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,10 +141,11 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - - gemm(L_lower_diagonal_blocks[i, :, :] - , L_lower_diagonal_blocks[i, :, :].conj().T, - A_diagonal_blocks[i + 1, :, :], -1.0, 1.0) + A_diagonal_blocks[i + 1, :, :] + - gemm(L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C' + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From f3fb2b5f19d16da3cbebd81fd82bb177476fa511 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:35:20 +0000 Subject: [PATCH 051/157] fixed different trans name --- src/serinv/block_primitive/gemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 6d290e81..a3ec4d77 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -151,7 +151,7 @@ def _trans_to_cublas_op(trans): trans = cublas.CUBLAS_OP_N elif trans == 'T' or trans == cublas.CUBLAS_OP_T: trans = cublas.CUBLAS_OP_T - elif trans == 'H' or trans == cublas.CUBLAS_OP_C: + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: trans = cublas.CUBLAS_OP_C else: raise TypeError('invalid trans (actual: {})'.format(trans)) @@ -192,9 +192,9 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): """Computes out = alpha * op(a) @ op(b) + beta * out op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', - op(a) = a.T.conj() if transa is 'H'. + op(a) = a.T.conj() if transa is 'C'. op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', - op(b) = b.T.conj() if transb is 'H'. + op(b) = b.T.conj() if transb is 'C'. """ assert a.ndim == b.ndim == 2 assert a.dtype == b.dtype From a48f5687f6c8db4605e7f11ceccaa1b9d1533dab Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:36:19 +0000 Subject: [PATCH 052/157] used alpha param on first gemm --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 993da40b..08b2322a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,9 +142,9 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( A_diagonal_blocks[i + 1, :, :] - - gemm(L_lower_diagonal_blocks[i, :, :], + + gemm(L_lower_diagonal_blocks[i, :, :], L_lower_diagonal_blocks[i, :, :], - trans_b='C' + trans_b='C', alpha=-1.0 ) ) From 69da4ae227006f02df583d8a254ce6a34d8ca89a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:51:23 +0000 Subject: [PATCH 053/157] removed alpha and beta hardcoding --- src/serinv/block_primitive/gemm.py | 2 -- src/serinv/block_primitive/trsm.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index a3ec4d77..44673155 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -134,8 +134,6 @@ def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, ove else: dtype = np.promote_types(a1.dtype.char, 'f') - alpha = 1 - beta = 0 if beta == 0: x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) else: diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py index 0ca1e0ca..5f447e27 100644 --- a/src/serinv/block_primitive/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -305,7 +305,6 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) trsm, = get_blas_funcs(('trsm',), (a1, b1)) - print(trsm) if a1.dtype.char in 'fd': dtype = a1.dtype From 4e6a81ae6f54bcc362da19ecde10db7afd979909 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:53:47 +0000 Subject: [PATCH 054/157] changed to minus --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 08b2322a..bf8f9eee 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,7 +142,7 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( A_diagonal_blocks[i + 1, :, :] - + gemm(L_lower_diagonal_blocks[i, :, :], + - gemm(L_lower_diagonal_blocks[i, :, :], L_lower_diagonal_blocks[i, :, :], trans_b='C', alpha=-1.0 ) From f148ab3c1670a7b8a6ebb9c1800a345b258b9d27 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:55:54 +0000 Subject: [PATCH 055/157] inserted some debug messages --- src/serinv/block_primitive/gemm.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 44673155..bff8631c 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,7 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - + print(alpha) x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x @@ -134,6 +134,12 @@ def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, ove else: dtype = np.promote_types(a1.dtype.char, 'f') + print(a1) + print("###") + print(b1) + print("###") + print(alpha) + print("###") if beta == 0: x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) else: From c3de7469b639f47591134b85bd7061fc1c0ccd34 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:56:20 +0000 Subject: [PATCH 056/157] reverted minus --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index bf8f9eee..08b2322a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,7 +142,7 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( A_diagonal_blocks[i + 1, :, :] - - gemm(L_lower_diagonal_blocks[i, :, :], + + gemm(L_lower_diagonal_blocks[i, :, :], L_lower_diagonal_blocks[i, :, :], trans_b='C', alpha=-1.0 ) From dac04895057f2ebae88a44b86abae9bcb437b294 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:58:17 +0000 Subject: [PATCH 057/157] exposed alpha, beta and c for host gemm --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index bff8631c..13ab9ac1 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -20,7 +20,7 @@ def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): xp, la = _get_module_from_array(a) if xp == np: - return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) + return matmul_gemm_host(a, b, c, alpha, beta, trans_a, trans_b) elif xp == cp: return matmul_gemm_device(trans_a, trans_b, a, b, c, alpha, beta) else: From cada2cb8c3d722b0ad8419cdc89534d39927874a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:08:15 +0000 Subject: [PATCH 058/157] convert alpha to complex for cgemm and zgemm host --- src/serinv/block_primitive/gemm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 13ab9ac1..70144426 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -209,8 +209,10 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): func = cublas.dgemm elif dtype == 'F': func = cublas.cgemm + alpha = complex(alpha) elif dtype == 'D': func = cublas.zgemm + alpha = complex(alpha) else: raise TypeError('invalid dtype') From 13e2b79e3d7aa30e6a6282733b7226004e92054f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:11:20 +0000 Subject: [PATCH 059/157] inser dytpe debug --- src/serinv/block_primitive/gemm.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 70144426..e5eaea73 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,7 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - print(alpha) + print(alpha.dtype) x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x @@ -134,12 +134,6 @@ def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, ove else: dtype = np.promote_types(a1.dtype.char, 'f') - print(a1) - print("###") - print(b1) - print("###") - print(alpha) - print("###") if beta == 0: x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) else: @@ -209,10 +203,8 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): func = cublas.dgemm elif dtype == 'F': func = cublas.cgemm - alpha = complex(alpha) elif dtype == 'D': func = cublas.zgemm - alpha = complex(alpha) else: raise TypeError('invalid dtype') From 9d85d220ea29fa36bc2590934ab596b0f2934a19 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:13:04 +0000 Subject: [PATCH 060/157] changed type debug --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index e5eaea73..81f03e7d 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,7 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - print(alpha.dtype) + print(alpha.type()) x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x From ce4f0bcf8dd70f4557c6f4d4117c79b399ddae92 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:13:49 +0000 Subject: [PATCH 061/157] changed type debug again --- src/serinv/block_primitive/gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 81f03e7d..87cfc42c 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,6 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) + print(alpha) print(alpha.type()) x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x From fc12dc1e0305f7692ee8d208d26f44bbc565ef91 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:14:55 +0000 Subject: [PATCH 062/157] swapped order in function call --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 87cfc42c..baa2add4 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -20,7 +20,7 @@ def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): xp, la = _get_module_from_array(a) if xp == np: - return matmul_gemm_host(a, b, c, alpha, beta, trans_a, trans_b) + return matmul_gemm_host(a, b, alpha, beta, c, trans_a, trans_b) elif xp == cp: return matmul_gemm_device(trans_a, trans_b, a, b, c, alpha, beta) else: From 7887bb0aa9b4cfd3b1fbce285c32638ccb7dd0eb Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:15:37 +0000 Subject: [PATCH 063/157] removed debug --- src/serinv/block_primitive/gemm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index baa2add4..e0232472 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,8 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - print(alpha) - print(alpha.type()) + x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x From 1ae08d4ad8bc83ee35af3d03a917bd0942cb2a67 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:22:00 +0000 Subject: [PATCH 064/157] fully use gemm at first location --- src/serinv/algs/pobtaf.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 08b2322a..036d115c 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,10 +141,12 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - + gemm(L_lower_diagonal_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], - trans_b='C', alpha=-1.0 + + gemm( + L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 ) ) From cfd7bd7e91aabd43a9b84dea609cbe00aa29e87a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:23:14 +0000 Subject: [PATCH 065/157] changed check for existing c --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index e0232472..29eee7fe 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -97,7 +97,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) - if c != None: + if c not None: c1 = _asarray_validated(c, check_finite=check_finite) else: c1 = None From 993c17f5044c992452910f7afc58c0624261387c Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:24:28 +0000 Subject: [PATCH 066/157] fixed c not being able to be true --- src/serinv/block_primitive/gemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 29eee7fe..61926c1a 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -97,10 +97,10 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) - if c not None: - c1 = _asarray_validated(c, check_finite=check_finite) - else: + if c == None: c1 = None + else: + c1 = _asarray_validated(c, check_finite=check_finite) if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: raise ValueError('expected square matrix') From c3aa5050223382595080ae3627e17ef00854a05f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:26:02 +0000 Subject: [PATCH 067/157] fixed c again --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 61926c1a..cce19df2 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -97,7 +97,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) - if c == None: + if c is None: c1 = None else: c1 = _asarray_validated(c, check_finite=check_finite) From a38c39d3e2cbf4e801812703f009cf144cce4d9a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:26:55 +0000 Subject: [PATCH 068/157] further c fix --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index cce19df2..4fc20b87 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -108,7 +108,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') - if beta != 0 and c1 == None: + if beta != 0 and c1 is None: raise ValueError('expected C matrix') # accommodate empty arrays From a873567e2faf095686b7dc06bf3cd498c971d317 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:28:55 +0000 Subject: [PATCH 069/157] second gemm --- src/serinv/algs/pobtaf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 036d115c..7e81ae41 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,7 +141,6 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - gemm( L_lower_diagonal_blocks[i, :, :], L_lower_diagonal_blocks[i, :, :], @@ -152,8 +151,12 @@ def _pobtaf( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T From 73eeabecd15408478a7ee4bf7654f0802e9149d8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:36:13 +0000 Subject: [PATCH 070/157] removed square matrix check in gemm that was leftover from trsm --- src/serinv/block_primitive/gemm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 4fc20b87..e9d2d668 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -102,8 +102,6 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov else: c1 = _asarray_validated(c, check_finite=check_finite) - if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: - raise ValueError('expected square matrix') if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') From 899d7c94a96af4d916bb77560ad5e23c70bedb55 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:39:46 +0000 Subject: [PATCH 071/157] changed input validation --- src/serinv/block_primitive/gemm.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index e9d2d668..052cb6ae 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -102,9 +102,21 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov else: c1 = _asarray_validated(c, check_finite=check_finite) - - if a1.shape[0] != b1.shape[0]: - raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + if not trans_a and not trans_b: + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + elif trans_a and not trans_b: + if a1.shape[1] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + elif not trans_a and trans_b: + if a1.shape[0] != b1.shape[1]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + else: + if a1.shape[1] != b1.shape[1]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') if beta != 0 and c1 is None: raise ValueError('expected C matrix') From e02b484c19fdb5ee3602b62bf42029f67e0512c2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:41:17 +0000 Subject: [PATCH 072/157] third gemm --- src/serinv/algs/pobtaf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7e81ae41..57576074 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -161,8 +161,12 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + A_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) if factorize_last_block: From b3b8c2218ee7be6f913ac1958db86906d02eb878 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:43:04 +0000 Subject: [PATCH 073/157] full normal pobtaf gemm implemented --- src/serinv/algs/pobtaf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 57576074..4c3d85b7 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -186,8 +186,12 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :] - - L_lower_arrow_blocks[-1, :, :] @ L_lower_arrow_blocks[-1, :, :].conj().T + gemm( + L_lower_arrow_blocks[-1, :, :], + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 0bf8f588088988d3ad3fb455df5f17b41e3df8de Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:50:39 +0000 Subject: [PATCH 074/157] gemm in permuted pobtaf --- src/serinv/algs/pobtaf.py | 46 +++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 4c3d85b7..3631a522 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -257,40 +257,64 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - L_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + L_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + gemm( + buffer[i, :, :], + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - A_lower_arrow_blocks[0, :, :] - - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + buffer[i, :, :].conj().T, + A_lower_arrow_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) From 50312bb6780faf9fb8dacd8eb7e5ab3c7adac4ba Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:52:28 +0000 Subject: [PATCH 075/157] rollback to just one gemm --- src/serinv/algs/pobtaf.py | 41 ++++++++++----------------------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3631a522..858e7edb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -258,8 +258,8 @@ def _pobtaf_permuted( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( gemm( - L_lower_diagonal_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :] + @ L_lower_diagonal_blocks[i, :, :].conj().T, A_diagonal_blocks[i + 1, :, :], trans_b='C', alpha=-1.0, beta=1.0 ) @@ -267,54 +267,33 @@ def _pobtaf_permuted( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i + 1, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks[i + 1, :, :] + - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_arrow_blocks[i, :, :], - L_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + L_arrow_tip_block[:, :] + - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - gemm( - buffer[i, :, :], - buffer[i, :, :], - A_diagonal_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - gemm( - buffer[i, :, :], - L_lower_diagonal_blocks[i, :, :], - trans_b='C', alpha=-1.0 - ) + -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - buffer[i, :, :].conj().T, - A_lower_arrow_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks[0, :, :] + - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T ) From 2cf395aa6a3762e0c89bd34302af9a5945bb336a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:53:15 +0000 Subject: [PATCH 076/157] removed leftover conj t --- src/serinv/algs/pobtaf.py | 43 +++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 858e7edb..c430c245 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -247,7 +247,7 @@ def _pobtaf_permuted( L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :].conj().T, + A_lower_arrow_blocks[i, :, :], lower=True, ) .conj() @@ -258,8 +258,8 @@ def _pobtaf_permuted( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( gemm( - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T, + L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], A_diagonal_blocks[i + 1, :, :], trans_b='C', alpha=-1.0, beta=1.0 ) @@ -267,33 +267,54 @@ def _pobtaf_permuted( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - L_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + L_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + gemm( + buffer[i, :, :], + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - A_lower_arrow_blocks[0, :, :] - - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + buffer[i, :, :].conj().T, + A_lower_arrow_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) From 3b3b244984e3f40356dc2a3cd2a9910e16ae60e2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:47:11 +0000 Subject: [PATCH 077/157] rollback to 1 gemm in permuted --- src/serinv/algs/pobtaf.py | 39 +++++++++------------------------------ 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index c430c245..efa94b0d 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -247,7 +247,7 @@ def _pobtaf_permuted( L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :], + A_lower_arrow_blocks[i, :, :].conj().T, lower=True, ) .conj() @@ -267,54 +267,33 @@ def _pobtaf_permuted( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i + 1, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks[i + 1, :, :] + - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_arrow_blocks[i, :, :], - L_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + L_arrow_tip_block[:, :] + - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - gemm( - buffer[i, :, :], - buffer[i, :, :], - A_diagonal_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - gemm( - buffer[i, :, :], - L_lower_diagonal_blocks[i, :, :], - trans_b='C', alpha=-1.0 - ) + -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - buffer[i, :, :].conj().T, - A_lower_arrow_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks[0, :, :] + - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T ) From 88b3a4cfdf6542530823630db327626546dfe085 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:49:06 +0000 Subject: [PATCH 078/157] next gemm in permuted --- src/serinv/algs/pobtaf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index efa94b0d..43fde541 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -267,8 +267,12 @@ def _pobtaf_permuted( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the block at the tip of the arrowhead From 83f7247bb3e09cd0ef7aa2922997df5ff0db9d29 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:50:14 +0000 Subject: [PATCH 079/157] another gemm --- src/serinv/algs/pobtaf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 43fde541..a0f1d720 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -278,8 +278,12 @@ def _pobtaf_permuted( # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - L_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + L_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern From 7cae093a90d3ef331be7759fd3d8b020cb822b1d Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:51:25 +0000 Subject: [PATCH 080/157] next gemm --- src/serinv/algs/pobtaf.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index a0f1d720..56347f21 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -289,7 +289,12 @@ def _pobtaf_permuted( # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + gemm( + buffer[i, :, :], + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T From 3f889e1220d469292f1de3ab1dd1d7e4da5acc9c Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:52:30 +0000 Subject: [PATCH 081/157] smaller gemm --- src/serinv/algs/pobtaf.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 56347f21..3b41a890 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -299,7 +299,12 @@ def _pobtaf_permuted( # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) + ) # Update the top (first blocks) of the arrowhead From 3c85f7bace8c6fea01ba3e534ddd19263242cb03 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:53:48 +0000 Subject: [PATCH 082/157] last permuted gemm --- src/serinv/algs/pobtaf.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3b41a890..b00ffb36 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -303,15 +303,18 @@ def _pobtaf_permuted( buffer[i, :, :], L_lower_diagonal_blocks[i, :, :], trans_b='C', alpha=-1.0 - ) - + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - A_lower_arrow_blocks[0, :, :] - - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + buffer[i, :, :], + A_lower_arrow_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) From ad936cbfcb64927243f909761ca412a6221d7731 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:57:38 +0000 Subject: [PATCH 083/157] first gemm in streaming --- src/serinv/algs/pobtaf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b00ffb36..81c24a5a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -481,9 +481,12 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From 9c4f49628ad7b2cc30ad989ed861446c6ac8a185 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:58:43 +0000 Subject: [PATCH 084/157] second gemm streaming --- src/serinv/algs/pobtaf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 81c24a5a..3afb3a60 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -491,9 +491,12 @@ def _pobtaf_streaming( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) From e0e5bdd040382e5f3834090fec730e661cf69991 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:59:51 +0000 Subject: [PATCH 085/157] third gemm streaming --- src/serinv/algs/pobtaf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3afb3a60..8d29b1a7 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -502,9 +502,12 @@ def _pobtaf_streaming( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) From 4bb8e6ef0eba8d758abe8437013b0d4bb7e82549 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 13:04:48 +0000 Subject: [PATCH 086/157] two permuted streaming gemms --- src/serinv/algs/pobtaf.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 8d29b1a7..83104526 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -553,9 +553,12 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] - @ L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) @@ -774,16 +777,22 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T From 480d982b5ab2138a4dc286c22c1b7e6ac59ac48e Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 13:08:47 +0000 Subject: [PATCH 087/157] implemented gemms for permuted streaming --- src/serinv/algs/pobtaf.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 83104526..0975d58a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -797,34 +797,46 @@ def _pobtaf_permuted_streaming( # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - -L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + trans_b='C', alpha=-1.0 + ) ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_arrow_bottom_top_block_d[:, :] = ( - A_arrow_bottom_top_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_arrow_bottom_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) cp_arrow_events_h2d_release[i % 2].record(stream=compute_stream) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - A_diagonal_top_block_d[:, :] - - L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_diagonal_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # --- Device 2 Host transfers --- From bfadd081db9e57572abe9d19ef105d1d234d5c82 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 08:30:07 +0000 Subject: [PATCH 088/157] implemented a form of syrk/herk and added a error for testing --- src/serinv/algs/pobtaf.py | 3 +- src/serinv/block_primitive/gemm.py | 85 ++++------------------------ src/serinv/block_primitive/syherk.py | 71 +++++++++++++++++++++++ src/serinv/block_primitive/trsm.py | 9 +-- 4 files changed, 87 insertions(+), 81 deletions(-) create mode 100644 src/serinv/block_primitive/syherk.py diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0975d58a..b81680b3 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -148,7 +148,8 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) - + print(A_diagonal_blocks[i + 1, :, :]) + raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 052cb6ae..8924cb2b 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -28,71 +28,12 @@ def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): - """ - Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. - - Parameters - ---------- - a : (M, M) array_like - A triangular matrix - b : (M,) or (M, N) array_like - Right-hand side matrix in ``a x = b`` - lower : bool, optional - Use only data contained in the lower triangle of `a`. - Default is to use upper triangle. - trans : {0, 1, 2, 'N', 'T', 'C'}, optional - Type of system to solve: - - ======== ========= - trans system - ======== ========= - 0 or 'N' a x = b - 1 or 'T' a^T x = b - 2 or 'C' a^H x = b - ======== ========= - unit_diagonal : bool, optional - If True, diagonal elements of `a` are assumed to be 1 and - will not be referenced. - overwrite_b : bool, optional - Allow overwriting data in `b` (may enhance performance) - check_finite : bool, optional - Whether to check that the input matrices contain only finite numbers. - Disabling may give a performance gain, but may result in problems - (crashes, non-termination) if the inputs do contain infinities or NaNs. - - Returns - ------- - x : (M,) or (M, N) ndarray - Solution to the system ``a x = b``. Shape of return matches `b`. - - Raises - ------ - LinAlgError - If `a` is singular - - Notes - ----- - .. versionadded:: 0.9.0 - - Examples - -------- - Solve the lower triangular system a x = b, where:: - - [3 0 0 0] [4] - a = [2 1 0 0] b = [2] - [1 0 1 0] [4] - [1 1 1 1] [2] - - >>> import numpy as np - >>> from scipy.linalg import solve_triangular - >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) - >>> b = np.array([4, 2, 4, 2]) - >>> x = solve_triangular(a, b, lower=True) - >>> x - array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) - >>> a.dot(x) # Check the result - array([ 4., 2., 4., 2.]) + """Computes out = alpha * op(a) @ op(b) + beta * out + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'C'. + op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', + op(b) = b.T.conj() if transb is 'C'. """ a1 = _asarray_validated(a, check_finite=check_finite) @@ -128,29 +69,27 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov ).dtype return np.empty_like(b1, dtype=dt_nonempty) + if c1 is not None: + overwrite_c = overwrite_c or _datacopied(c1, c) + x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x -# solve_triangular without the input validation +# gemm without the input validation def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) gemm, = get_blas_funcs(('gemm',), (a1, b1)) - if a1.dtype.char in 'fd': - dtype = a1.dtype - else: - dtype = np.promote_types(a1.dtype.char, 'f') - if beta == 0: - x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + out = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) else: - x = gemm(alpha, a1, b1, beta, c1, trans_a, trans_b, overwrite_c) + out = gemm(alpha, a1, b1, beta, c1, trans_a, trans_b, overwrite_c) - return x + return out # Util functions for cupy gemm diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py new file mode 100644 index 00000000..d708ff11 --- /dev/null +++ b/src/serinv/block_primitive/syherk.py @@ -0,0 +1,71 @@ +from serinv import _get_module_from_array + +import numpy as np +from numpy.linalg import matmul + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +try: + import cupy as cp + from cupy_backends.cuda.libs import cublas + from cupy import _core + from cupy.cuda import device +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def syherk(a, c=None, trans=0, lower = False, unit_diagonal=False, + overwrite_c=False, check_finite=False, side=0): + """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device + + For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept + plus the side parameter which can either be 0 or 1 for left or right hand side + """ + xp, la = _get_module_from_array(a) + if xp == np: + return matmul_syherk_host(a, c, trans, lower, unit_diagonal, overwrite_c, check_finite, side) + elif xp == cp: + return matmul_syherk_device(a, c, trans, lower, unit_diagonal, overwrite_c, check_finite, side) + else: + ModuleNotFoundError("Unknown Module") + +def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, unit_diagonal=False, + overwrite_c=False, check_finite=True, side=0): + """Computes out = alpha * op(a) @ op(a)^T + beta * b + + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'C'. + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + if c is None: + c1 = None + else: + c1 = _asarray_validated(c, check_finite=check_finite) + + + if a1.shape[0] != c1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and c {c1.shape} are incompatible') + + + overwrite_c = overwrite_c or _datacopied(c1, c) + + x = _syherk(a1, c1, alpha, beta, trans, lower, unit_diagonal, overwrite_c, side) + return x + + +# syherk without the input validation +def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, unit_diagonal=False, + overwrite_c=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + + if np.iscomplexobj(a1): + syherk, = get_blas_funcs(('herk'), (a1, c1)) + else: + syherk, = get_blas_funcs(('syrk'), (a1, c1)) + + out = syherk(alpha, a1, beta, c1, trans, lower, overwrite_c) + + return out \ No newline at end of file diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py index 5f447e27..0185bed9 100644 --- a/src/serinv/block_primitive/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -306,18 +306,13 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) trsm, = get_blas_funcs(('trsm',), (a1, b1)) - if a1.dtype.char in 'fd': - dtype = a1.dtype - else: - dtype = np.promote_types(a1.dtype.char, 'f') - - alpha = 1 + alpha = 1.0 if a1.flags.f_contiguous or trans == 2: x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, trans_a=trans, diag=unit_diagonal, side=side) else: - # transposed system is solved since trtrs expects Fortran ordering + # transposed system is solved since trsm expects Fortran ordering x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, trans_a=not trans, diag=unit_diagonal, side=side) From 4e6bbde3ea75a3c70a922c5a095e4a0834b04060 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 08:30:55 +0000 Subject: [PATCH 089/157] added another print for debug --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b81680b3..06f7731f 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -140,6 +140,7 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T + print(A_diagonal_blocks[i + 1, :, :]) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], From f8ccb36ffe94e05e5a2cb938a538a33ab715b721 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 08:33:09 +0000 Subject: [PATCH 090/157] added another print --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 06f7731f..e6ad9b74 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,6 +141,7 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks[i + 1, :, :]) + print(L_lower_diagonal_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], From 256e3d0c2d57ae5a4acfed3d1aaa1901d64cb7d6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:01:36 +0000 Subject: [PATCH 091/157] implemented syherk. sadly it's not yet useful --- src/serinv/block_primitive/__init__.py | 4 +- src/serinv/block_primitive/syherk.py | 139 ++++++++++++++++++++++--- 2 files changed, 130 insertions(+), 13 deletions(-) diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py index b8cf0541..4d3895ec 100644 --- a/src/serinv/block_primitive/__init__.py +++ b/src/serinv/block_primitive/__init__.py @@ -1,7 +1,9 @@ from serinv.block_primitive.gemm import gemm from serinv.block_primitive.trsm import trsm +from serinv.block_primitive.trsm import syherk __all__ = [ "gemm", - "trsm" + "trsm", + "syherk" ] \ No newline at end of file diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index d708ff11..388a6b5f 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -10,13 +10,11 @@ try: import cupy as cp from cupy_backends.cuda.libs import cublas - from cupy import _core from cupy.cuda import device except (ImportError, ImportWarning, ModuleNotFoundError): pass -def syherk(a, c=None, trans=0, lower = False, unit_diagonal=False, - overwrite_c=False, check_finite=False, side=0): +def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower = False): """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept @@ -24,9 +22,9 @@ def syherk(a, c=None, trans=0, lower = False, unit_diagonal=False, """ xp, la = _get_module_from_array(a) if xp == np: - return matmul_syherk_host(a, c, trans, lower, unit_diagonal, overwrite_c, check_finite, side) + return matmul_syherk_host(a, c, alpha, beta, trans, lower) elif xp == cp: - return matmul_syherk_device(a, c, trans, lower, unit_diagonal, overwrite_c, check_finite, side) + return matmul_syherk_device(a, trans, c, alpha, beta, lower) else: ModuleNotFoundError("Unknown Module") @@ -43,11 +41,6 @@ def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, uni c1 = None else: c1 = _asarray_validated(c, check_finite=check_finite) - - - if a1.shape[0] != c1.shape[0]: - raise ValueError(f'shapes of a {a1.shape} and c {c1.shape} are incompatible') - overwrite_c = overwrite_c or _datacopied(c1, c) @@ -56,8 +49,8 @@ def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, uni # syherk without the input validation -def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, unit_diagonal=False, - overwrite_c=False, side=0): +def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, + overwrite_c=False): trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) @@ -68,4 +61,126 @@ def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, unit_diagona out = syherk(alpha, a1, beta, c1, trans, lower, overwrite_c) + return out + + + +# Util functions for cupy gemm +def _trans_to_cublas_op(trans): + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + trans = cublas.CUBLAS_OP_N + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + trans = cublas.CUBLAS_OP_T + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: + trans = cublas.CUBLAS_OP_C + else: + raise TypeError('invalid trans (actual: {})'.format(trans)) + return trans + +def _decide_ld_and_trans(a, trans): + ld = None + if trans in (cublas.CUBLAS_OP_N, cublas.CUBLAS_OP_T): + if a._f_contiguous: + ld = a.shape[0] + elif a._c_contiguous: + ld = a.shape[1] + trans = 1 - trans + return ld, trans + +def _get_scalar_ptr(a, dtype): + if isinstance(a, cp.ndarray): + if a.dtype != dtype: + a = cp.array(a, dtype=dtype) + a_ptr = a.data.ptr + else: + if not (isinstance(a, np.ndarray) and a.dtype == dtype): + a = np.array(a, dtype=dtype) + a_ptr = a.ctypes.data + return a, a_ptr +# Util functions for cupy gemm end + +def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=False): + """Computes out := alpha*op1(a)*op2(a) + beta*out + + op1(a) = a if trans is 'N', op2(a) = a.T if transa is 'N' + op1(a) = a.T if trans is 'T', op2(a) = a if transa is 'T' + lower specifies whether the upper or lower triangular + part of the array out is to be referenced + """ + assert a.ndim == 2 + dtype = a.dtype.char + if dtype == 'f': + func = cublas.ssyrk + elif dtype == 'd': + func = cublas.dsyrk + elif dtype == 'F': + func = cublas.cherk + elif dtype == 'D': + func = cublas.zherk + else: + raise TypeError('invalid dtype') + + trans = _trans_to_cublas_op(trans) + if trans == cublas.CUBLAS_OP_N: + n, k = a.shape + else: + k, n = a.shape + if out is None: + out = cp.zeros((n, n), dtype=dtype, order='F') + beta = 0.0 + else: + assert out.ndim == 2 + assert out.shape == (n, n) + assert out.dtype == dtype + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype) + beta, beta_ptr = _get_scalar_ptr(beta, a.dtype) + handle = device.get_cublas_handle() + orig_mode = cublas.getPointerMode(handle) + if isinstance(alpha, cp.ndarray) or isinstance(beta, cp.ndarray): + if not isinstance(alpha, cp.ndarray): + alpha = cp.array(alpha) + alpha_ptr = alpha.data.ptr + if not isinstance(beta, cp.ndarray): + beta = cp.array(beta) + beta_ptr = beta.data.ptr + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_DEVICE) + else: + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_HOST) + + lda, trans = _decide_ld_and_trans(a, trans) + ldo, _ = _decide_ld_and_trans(out, trans) + if out._c_contiguous: + if not a._c_contiguous: + a = a.copy(order='C') + trans = 1 - trans + lda = a.shape[1] + try: + func(handle, 1 - uplo, trans, n, k, + alpha_ptr, a.data.ptr, lda, + beta_ptr, out.data.ptr, ldo) + finally: + cublas.setPointerMode(handle, orig_mode) + + else: + if not a._f_contiguous: + a = a.copy(order='F') + lda = a.shape[0] + trans = 1 - trans + c = out + if not out._f_contiguous: + c = out.copy(order='F') + try: + func(handle, uplo, trans, n, k, + alpha_ptr, a.data.ptr, lda, + beta_ptr, out.data.ptr, ldo) + finally: + cublas.setPointerMode(handle, orig_mode) + if not out._f_contiguous: + out[...] = c return out \ No newline at end of file From 9d8ea282eb507659ad8b77863cfd3c0b35cb3ec8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:04:53 +0000 Subject: [PATCH 092/157] removed debug prints --- src/serinv/algs/pobtaf.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index e6ad9b74..0975d58a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -140,8 +140,6 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T - print(A_diagonal_blocks[i + 1, :, :]) - print(L_lower_diagonal_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], @@ -150,8 +148,7 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) - print(A_diagonal_blocks[i + 1, :, :]) - raise ValueError("TEST") + # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( From 588533817e416ff4d025b6eb76f1da40c18dbec5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:33:22 +0000 Subject: [PATCH 093/157] added test error --- src/serinv/algs/pobtaf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0975d58a..8e98793f 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -256,6 +256,7 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T + print(L_lower_diagonal_blocks[i, :, :]) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], @@ -264,7 +265,7 @@ def _pobtaf_permuted( trans_b='C', alpha=-1.0, beta=1.0 ) ) - + raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( From 1436b7f149d4d6c085a650ae2f45ba452099d3ea Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:33:58 +0000 Subject: [PATCH 094/157] fixed typo --- src/serinv/block_primitive/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py index 4d3895ec..89b14a73 100644 --- a/src/serinv/block_primitive/__init__.py +++ b/src/serinv/block_primitive/__init__.py @@ -1,6 +1,6 @@ from serinv.block_primitive.gemm import gemm from serinv.block_primitive.trsm import trsm -from serinv.block_primitive.trsm import syherk +from serinv.block_primitive.syherk import syherk __all__ = [ "gemm", From b8a5a25be5c6dd5a4b4420cef335aa18f049caa8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:37:04 +0000 Subject: [PATCH 095/157] moved test --- src/serinv/algs/pobtaf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 8e98793f..73583a25 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -126,6 +126,7 @@ def _pobtaf( ) ) + # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( @@ -148,7 +149,7 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) - + print(A_diagonal_blocks[i + 1, :, :]) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( @@ -168,6 +169,8 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) + print(A_arrow_tip_block[:, :]) + raise ValueError("TEST") if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) @@ -256,7 +259,6 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T - print(L_lower_diagonal_blocks[i, :, :]) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], @@ -265,7 +267,6 @@ def _pobtaf_permuted( trans_b='C', alpha=-1.0, beta=1.0 ) ) - raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( From 2bf27a49f2615a2cb293e93c983fdcf7b3048dc9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:43:19 +0000 Subject: [PATCH 096/157] attempt for using syherk --- src/serinv/algs/pobtaf.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 73583a25..96a8edc0 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -7,7 +7,7 @@ _get_cholesky, ) -from serinv.block_primitive import trsm, gemm +from serinv.block_primitive import trsm, gemm, syherk def pobtaf( A_diagonal_blocks: ArrayLike, @@ -170,7 +170,6 @@ def _pobtaf( ) ) print(A_arrow_tip_block[:, :]) - raise ValueError("TEST") if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) @@ -188,14 +187,21 @@ def _pobtaf( ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} + # A_arrow_tip_block[:, :] = ( + # gemm( + # L_lower_arrow_blocks[-1, :, :], + # L_lower_arrow_blocks[-1, :, :], + # A_arrow_tip_block[:, :], + # trans_b='C', alpha=-1.0, beta=1.0 + # ) + # ) + A_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[-1, :, :], + syherk( L_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True ) - ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From e62d9ff84ad24234af68ca10fa31ccdcdd30c60e Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:43:44 +0000 Subject: [PATCH 097/157] missing parenthesis --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 96a8edc0..b7fb8a06 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -202,6 +202,7 @@ def _pobtaf( A_arrow_tip_block[:, :], alpha=-1.0, beta=1.0, lower=True ) + ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From bc52ddc6929e962562d7112d3c6b7e5f8635cebc Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:45:22 +0000 Subject: [PATCH 098/157] changed input for _syherk --- src/serinv/block_primitive/syherk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 388a6b5f..90085098 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -44,7 +44,7 @@ def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, uni overwrite_c = overwrite_c or _datacopied(c1, c) - x = _syherk(a1, c1, alpha, beta, trans, lower, unit_diagonal, overwrite_c, side) + x = _syherk(a1, c1, alpha, beta, trans, lower, overwrite_c) return x From cf3f7221e98eeaee9ddd86c9065463d1d3af6cfc Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:46:08 +0000 Subject: [PATCH 099/157] removed iteration error --- src/serinv/block_primitive/syherk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 90085098..4478c1fb 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -55,9 +55,9 @@ def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) if np.iscomplexobj(a1): - syherk, = get_blas_funcs(('herk'), (a1, c1)) + syherk = get_blas_funcs(('herk'), (a1, c1)) else: - syherk, = get_blas_funcs(('syrk'), (a1, c1)) + syherk = get_blas_funcs(('syrk'), (a1, c1)) out = syherk(alpha, a1, beta, c1, trans, lower, overwrite_c) From 3b92a1de4e6ac287192661de8fb80e147c03decc Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 11:41:02 +0000 Subject: [PATCH 100/157] attempt at implementing herk --- .../cupyfix_backends/cuda/cupy_cublas.h | 24 +++ .../cupyfix_backends/cuda/libs/cublas.pxd | 34 ++++ .../cupyfix_backends/cuda/libs/cublas.pyx | 112 ++++++++++++ .../cupyfix_backends/cupy_blas.h | 17 ++ .../cupyfix_backends/cupy_complex.h | 17 ++ .../cupyfix_backends/hip/cupy_cuComplex.h | 20 +++ .../cupyfix_backends/hip/cupy_hip_common.h | 161 ++++++++++++++++++ .../cupyfix_backends/hip/cupy_hipblas.h | 89 ++++++++++ .../cupyfix_backends/stub/cupy_cuComplex.h | 22 +++ .../cupyfix_backends/stub/cupy_cublas.h | 21 +++ src/serinv/block_primitive/syherk.py | 6 +- 11 files changed, 521 insertions(+), 2 deletions(-) create mode 100644 src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h create mode 100644 src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd create mode 100644 src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx create mode 100644 src/serinv/block_primitive/cupyfix_backends/cupy_blas.h create mode 100644 src/serinv/block_primitive/cupyfix_backends/cupy_complex.h create mode 100644 src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h create mode 100644 src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h create mode 100644 src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h create mode 100644 src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h create mode 100644 src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h b/src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h new file mode 100644 index 00000000..c3d4874b --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h @@ -0,0 +1,24 @@ +#ifndef INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H +#define INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H + +#include +#include + +#if CUDA_VERSION >= 11000 + +#define cublasGemmEx_v11 cublasGemmEx +#define cublasGemmStridedBatchedEx_v11 cublasGemmStridedBatchedEx + +#else + +typedef enum{} cublasComputeType_t; +cublasStatus_t cublasGemmEx_v11(...) { + return CUBLAS_STATUS_NOT_SUPPORTED; +} +cublasStatus_t cublasGemmStridedBatchedEx_v11(...) { + return CUBLAS_STATUS_NOT_SUPPORTED; +} + +#endif // if CUDA_VERSION >= 11000 + +#endif // #ifndef INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H \ No newline at end of file diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd new file mode 100644 index 00000000..213630de --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd @@ -0,0 +1,34 @@ +"""Thin wrapper of CUBLAS.""" +from libc.stdint cimport intptr_t + + +############################################################################### +# Types +############################################################################### + +cdef extern from *: + ctypedef void* cuComplexPtr 'cuComplex*' + ctypedef void* cuDoubleComplexPtr 'cuDoubleComplex*' + + +cdef extern from *: + ctypedef void* Handle 'cublasHandle_t' + + ctypedef int DiagType 'cublasDiagType_t' + ctypedef int FillMode 'cublasFillMode_t' + ctypedef int Operation 'cublasOperation_t' + ctypedef int PointerMode 'cublasPointerMode_t' + ctypedef int SideMode 'cublasSideMode_t' + ctypedef int GemmAlgo 'cublasGemmAlgo_t' + ctypedef int Math 'cublasMath_t' + ctypedef int ComputeType 'cublasComputeType_t' + +############################################################################### +# BLAS Level 3 +############################################################################### + + +cpdef cherk(intptr_t handle, int uplo, int trans, int n, int k, + size_t alpha, size_t A, int lda, size_t beta, size_t C, int ldc) +cpdef zherk(intptr_t handle, int uplo, int trans, int n, int k, + size_t alpha, size_t A, int lda, size_t beta, size_t C, int ldc) diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx new file mode 100644 index 00000000..f406b3a8 --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx @@ -0,0 +1,112 @@ +# distutils: language = c++ + +"""Thin wrapper of CUBLAS.""" + +cimport cython # NOQA + +from cupy_backends.cuda.api cimport runtime +from cupy_backends.cuda cimport stream as stream_module + +############################################################################### +# Extern +############################################################################### + +cdef extern from '../../cupy_complex.h': + ctypedef struct cuComplex 'cuComplex': + float x, y + + ctypedef struct cuDoubleComplex 'cuDoubleComplex': + double x, y + +cdef extern from '../../cupy_blas.h' nogil: + ctypedef void* Stream 'cudaStream_t' + ctypedef int DataType 'cudaDataType' + + # BLAS Level 3 + int cublasCherk( + Handle handle, FillMode uplo, Operation trans, int n, int k, + cuComplex* alpha, cuComplex* A, int lda, + cuComplex* beta, cuComplex* C, int ldc) + int cublasZherk( + Handle handle, FillMode uplo, Operation trans, int n, int k, + cuDoubleComplex* alpha, cuDoubleComplex* A, int lda, + cuDoubleComplex* beta, cuDoubleComplex* C, int ldc) + +############################################################################### +# Error handling +############################################################################### + +cdef dict STATUS = { + 0: 'CUBLAS_STATUS_SUCCESS', + 1: 'CUBLAS_STATUS_NOT_INITIALIZED', + 3: 'CUBLAS_STATUS_ALLOC_FAILED', + 7: 'CUBLAS_STATUS_INVALID_VALUE', + 8: 'CUBLAS_STATUS_ARCH_MISMATCH', + 11: 'CUBLAS_STATUS_MAPPING_ERROR', + 13: 'CUBLAS_STATUS_EXECUTION_FAILED', + 14: 'CUBLAS_STATUS_INTERNAL_ERROR', + 15: 'CUBLAS_STATUS_NOT_SUPPORTED', + 16: 'CUBLAS_STATUS_LICENSE_ERROR', +} + + +cdef dict HIP_STATUS = { + 0: 'HIPBLAS_STATUS_SUCCESS', + 1: 'HIPBLAS_STATUS_NOT_INITIALIZED', + 2: 'HIPBLAS_STATUS_ALLOC_FAILED', + 3: 'HIPBLAS_STATUS_INVALID_VALUE', + 4: 'HIPBLAS_STATUS_MAPPING_ERROR', + 5: 'HIPBLAS_STATUS_EXECUTION_FAILED', + 6: 'HIPBLAS_STATUS_INTERNAL_ERROR', + 7: 'HIPBLAS_STATUS_NOT_SUPPORTED', + 8: 'HIPBLAS_STATUS_ARCH_MISMATCH', + 9: 'HIPBLAS_STATUS_HANDLE_IS_NULLPTR', +} + + +class CUBLASError(RuntimeError): + + def __init__(self, status): + self.status = status + cdef str err + if runtime._is_hip_environment: + err = HIP_STATUS[status] + else: + err = STATUS[status] + super(CUBLASError, self).__init__(err) + + def __reduce__(self): + return (type(self), (self.status,)) + + +@cython.profile(False) +cpdef inline check_status(int status): + if status != 0: + raise CUBLASError(status) + + + +############################################################################### +# BLAS Level 3 +############################################################################### + +cpdef cherk(intptr_t handle, int uplo, int trans, int n, int k, + size_t alpha, size_t A, int lda, size_t beta, size_t C, int ldc): + _setStream(handle) + with nogil: + status = cublasCherk( + handle, uplo, trans, n, k, + alpha, A, lda, + beta, C, ldc) + check_status(status) + + +cpdef zherk(intptr_t handle, int uplo, int trans, int n, int k, + size_t alpha, size_t A, int lda, size_t beta, size_t C, int ldc): + _setStream(handle) + with nogil: + status = cublasZherk( + handle, uplo, trans, n, k, + alpha, A, lda, + beta, C, ldc) + check_status(status) \ No newline at end of file diff --git a/src/serinv/block_primitive/cupyfix_backends/cupy_blas.h b/src/serinv/block_primitive/cupyfix_backends/cupy_blas.h new file mode 100644 index 00000000..ad2cffa5 --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/cupy_blas.h @@ -0,0 +1,17 @@ +#ifndef INCLUDE_GUARD_CUPY_CUBLAS_H +#define INCLUDE_GUARD_CUPY_CUBLAS_H + +#if CUPY_USE_HIP + +#include "hip/cupy_hipblas.h" + +#elif !defined(CUPY_NO_CUDA) + +#include "cuda/cupy_cublas.h" + +#else // #ifndef CUPY_NO_CUDA + +#include "stub/cupy_cublas.h" + +#endif // #ifndef CUPY_NO_CUDA +#endif // #ifndef INCLUDE_GUARD_CUPY_CUBLAS_H \ No newline at end of file diff --git a/src/serinv/block_primitive/cupyfix_backends/cupy_complex.h b/src/serinv/block_primitive/cupyfix_backends/cupy_complex.h new file mode 100644 index 00000000..5c7efed9 --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/cupy_complex.h @@ -0,0 +1,17 @@ +#ifndef INCLUDE_GUARD_CUPY_COMPLEX_H +#define INCLUDE_GUARD_CUPY_COMPLEX_H + +#ifdef CUPY_USE_HIP + +#include "hip/cupy_cuComplex.h" + +#elif !defined(CUPY_NO_CUDA) + +#include + +#else // #if !defined(CUPY_NO_CUDA) || !defined(CUPY_USE_HIP) + +#include "stub/cupy_cuComplex.h" + +#endif // #ifndef CUPY_NO_CUDA +#endif // #ifndef INCLUDE_GUARD_CUPY_COMPLEX_H \ No newline at end of file diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h new file mode 100644 index 00000000..dfb6006d --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h @@ -0,0 +1,20 @@ +#ifndef INCLUDE_GUARD_HIP_CUPY_COMPLEX_H +#define INCLUDE_GUARD_HIP_CUPY_COMPLEX_H + +extern "C" { + +/////////////////////////////////////////////////////////////////////////////// +// cuComplex.h +/////////////////////////////////////////////////////////////////////////////// + +struct cuComplex{ + float x, y; +}; + +struct cuDoubleComplex{ + double x, y; +}; + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_HIP_CUPY_COMPLEX_H \ No newline at end of file diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h new file mode 100644 index 00000000..d1732ddc --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h @@ -0,0 +1,161 @@ +#ifndef INCLUDE_GUARD_HIP_CUPY_COMMON_H +#define INCLUDE_GUARD_HIP_CUPY_COMMON_H + +#include +#include +#include + +#define CUDA_VERSION 0 + +extern "C" { + +/////////////////////////////////////////////////////////////////////////////// +// cuda.h +/////////////////////////////////////////////////////////////////////////////// + +typedef int CUdevice; +typedef hipError_t CUresult; +// Conditionally define CUDA_SUCCESS only if it's not defined +#ifndef CUDA_SUCCESS +const CUresult CUDA_SUCCESS = static_cast(0); +#endif +enum CUjit_option {}; +enum CUjitInputType {}; +enum CUarray_format {}; +enum CUaddress_mode {}; +enum CUfilter_mode {}; + + +typedef hipDeviceptr_t CUdeviceptr; +struct CUlinkState_st; + + +typedef hipCtx_t CUcontext; +typedef hipEvent_t CUevent; +typedef hipEvent_t cudaEvent_t; +typedef hipFunction_t CUfunction; +typedef hipFunction_attribute CUfunction_attribute; +typedef hipModule_t CUmodule; +typedef hipStream_t CUstream; +typedef hipStream_t cudaStream_t; +#if HIP_VERSION >= 40300000 +typedef hipGraph_t cudaGraph_t; +typedef hipGraphNode_t cudaGraphNode_t; +typedef hipGraphExec_t cudaGraphExec_t; +#else +typedef void* cudaGraph_t; +typedef void* cudaGraphNode_t; +typedef void* cudaGraphExec_t; +#endif +typedef struct CUlinkState_st* CUlinkState; +typedef struct CUarray_st* CUarray; +struct CUDA_ARRAY_DESCRIPTOR { + CUarray_format Format; + size_t Height; + unsigned int NumChannels; + size_t Width; +}; + + +/////////////////////////////////////////////////////////////////////////////// +// cuda_runtime.h +/////////////////////////////////////////////////////////////////////////////// + +enum { + cudaDevAttrComputeCapabilityMajor + = hipDeviceAttributeComputeCapabilityMajor, + cudaDevAttrComputeCapabilityMinor + = hipDeviceAttributeComputeCapabilityMinor, +}; + +typedef hipError_t cudaError_t; +const CUresult cudaSuccess = static_cast(0); +const CUresult cudaErrorInvalidValue = hipErrorInvalidValue; +const CUresult cudaErrorMemoryAllocation = hipErrorMemoryAllocation; +const CUresult cudaErrorInvalidResourceHandle = hipErrorInvalidResourceHandle; +const CUresult cudaErrorContextIsDestroyed = hipErrorUnknown; // no counterpart in HIP +const CUresult cudaErrorPeerAccessAlreadyEnabled = hipErrorPeerAccessAlreadyEnabled; +typedef enum {} cudaDataType; +typedef hipDeviceAttribute_t cudaDeviceAttr; +typedef hipLimit_t cudaLimit; +typedef hipMemoryAdvise cudaMemoryAdvise; +typedef hipMemcpyKind cudaMemcpyKind; +typedef hipDeviceProp_t cudaDeviceProp; +typedef void* cudaMemPool_t; +enum cudaMemPoolAttr {}; + +typedef hipStreamCallback_t cudaStreamCallback_t; +typedef void (*cudaHostFn_t)(void* userData); +typedef hipPointerAttribute_t cudaPointerAttributes; + +typedef hipChannelFormatKind cudaChannelFormatKind; +typedef hipTextureObject_t cudaTextureObject_t; +typedef hipSurfaceObject_t cudaSurfaceObject_t; +typedef hipResourceType cudaResourceType; +typedef hipTextureAddressMode cudaTextureAddressMode; +typedef hipTextureFilterMode cudaTextureFilterMode; +typedef hipTextureReadMode cudaTextureReadMode; +typedef hipResourceViewDesc cudaResourceViewDesc; +typedef hipArray_t cudaArray_t; +typedef hipExtent cudaExtent; +typedef hipPos cudaPos; +typedef hipPitchedPtr cudaPitchedPtr; +typedef hipMipmappedArray_t cudaMipmappedArray_t; +typedef hipMemcpy3DParms cudaMemcpy3DParms; +typedef hipChannelFormatDesc cudaChannelFormatDesc; +typedef hipResourceDesc cudaResourceDesc; +typedef hipTextureDesc cudaTextureDesc; + +// IPC operations +typedef hipIpcMemHandle_st cudaIpcMemHandle_t; +typedef hipIpcEventHandle_st cudaIpcEventHandle_t; + + +/////////////////////////////////////////////////////////////////////////////// +// blas & lapack (hipBLAS/rocBLAS & rocSOLVER) +/////////////////////////////////////////////////////////////////////////////// + +/* As of ROCm 3.5.0 (this may have started earlier) many rocSOLVER helper functions + * are deprecated and using their counterparts from rocBLAS is recommended. In + * particular, rocSOLVER simply uses rocBLAS's handle for its API calls. This means + * they are much more integrated than cuBLAS and cuSOLVER do, so it is better to + * put all of the relevant function in one place. + */ + +// TODO(leofang): investigate if we should just remove the hipBLAS layer and use +// rocBLAS directly, since we need to expose its handle anyway + + +typedef hipblasHandle_t cublasHandle_t; + +typedef hipblasDiagType_t cublasDiagType_t; +typedef hipblasFillMode_t cublasFillMode_t; +typedef hipblasOperation_t cublasOperation_t; +typedef hipblasPointerMode_t cublasPointerMode_t; +typedef hipblasSideMode_t cublasSideMode_t; +typedef enum {} cublasGemmAlgo_t; +typedef enum {} cublasMath_t; +typedef int cudaDataType_t; +typedef hipblasStatus_t cublasStatus_t; + +// TODO(leofang): as of ROCm 3.5.0 this does not exist yet +typedef enum {} cublasComputeType_t; + +typedef rocblas_status cusolverStatus_t; +typedef rocblas_handle cusolverDnHandle_t; + + +/////////////////////////////////////////////////////////////////////////////// +// library_types.h +// (needed for supporting cusolver) +/////////////////////////////////////////////////////////////////////////////// + +typedef enum libraryPropertyType_t { + MAJOR_VERSION, + MINOR_VERSION, + PATCH_LEVEL +} libraryPropertyType; + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_HIP_CUPY_COMMON_H \ No newline at end of file diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h new file mode 100644 index 00000000..f48366c2 --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h @@ -0,0 +1,89 @@ +#ifndef INCLUDE_GUARD_HIP_CUPY_HIPBLAS_H +#define INCLUDE_GUARD_HIP_CUPY_HIPBLAS_H + +#include "cupy_hip_common.h" +#include +#include // for HIP_VERSION +#include // for gcc 10 + + +extern "C" { + +/////////////////////////////////////////////////////////////////////////////// +// blas & lapack (hipBLAS/rocBLAS & rocSOLVER) +/////////////////////////////////////////////////////////////////////////////// + +/* As of ROCm 3.5.0 (this may have started earlier) many rocSOLVER helper functions + * are deprecated and using their counterparts from rocBLAS is recommended. In + * particular, rocSOLVER simply uses rocBLAS's handle for its API calls. This means + * they are much more integrated than cuBLAS and cuSOLVER are, so it is better to + * put all of the relevant function in one place. + */ + +// TODO(leofang): investigate if we should just remove the hipBLAS layer and use +// rocBLAS directly, since we need to expose its handle anyway + + +/* ---------- helpers ---------- */ +static hipblasOperation_t convert_hipblasOperation_t(cublasOperation_t op) { + return static_cast(static_cast(op) + 111); +} + +static hipblasFillMode_t convert_hipblasFillMode_t(cublasFillMode_t mode) { + switch(static_cast(mode)) { + case 0 /* CUBLAS_FILL_MODE_LOWER */: return HIPBLAS_FILL_MODE_LOWER; + case 1 /* CUBLAS_FILL_MODE_UPPER */: return HIPBLAS_FILL_MODE_UPPER; + default: throw std::runtime_error("unrecognized mode"); + } +} + +static hipblasDiagType_t convert_hipblasDiagType_t(cublasDiagType_t type) { + return static_cast(static_cast(type) + 131); +} + +static hipblasSideMode_t convert_hipblasSideMode_t(cublasSideMode_t mode) { + return static_cast(static_cast(mode) + 141); +} + +static hipblasDatatype_t convert_hipblasDatatype_t(cudaDataType_t type) { + switch(static_cast(type)) { + case 0 /* CUDA_R_32F */: return HIPBLAS_R_32F; + case 1 /* CUDA_R_64F */: return HIPBLAS_R_64F; + case 2 /* CUDA_R_16F */: return HIPBLAS_R_16F; + case 3 /* CUDA_R_8I */ : return HIPBLAS_R_8I; + case 4 /* CUDA_C_32F */: return HIPBLAS_C_32F; + case 5 /* CUDA_C_64F */: return HIPBLAS_C_64F; + case 6 /* CUDA_C_16F */: return HIPBLAS_C_16F; + case 7 /* CUDA_C_8I */ : return HIPBLAS_C_8I; + case 8 /* CUDA_R_8U */ : return HIPBLAS_R_8U; + case 9 /* CUDA_C_8U */ : return HIPBLAS_C_8U; + default: throw std::runtime_error("unrecognized type"); + } +} + +// BLAS Level 3 +cublasStatus_t cublasCherk(cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, int n, int k, + const cuComplex* alpha, const cuComplex* A,int lda, + const cuComplex* beta, cuComplex* C, int ldc) +{ + return hipblasCherk(handle, convert_hipblasFillMode_t(uplo), convert_hipblasOperation_t(trans), n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(beta), + reinterpret_cast(C), ldc); +} + +cublasStatus_t cublasZherk(cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, int n, int k, + const cuDoubleComplex* alpha, const cuDoubleComplex* A, int lda, + const cuDoubleComplex* beta, cuDoubleComplex* C, int ldc) +{ + return hipblasZherk(handle, convert_hipblasFillMode_t(uplo), convert_hipblasOperation_t(trans), n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(beta), + reinterpret_cast(C), ldc); +} + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_HIP_CUPY_HIPBLAS_H \ No newline at end of file diff --git a/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h b/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h new file mode 100644 index 00000000..4e67db36 --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h @@ -0,0 +1,22 @@ +// This file is a stub header file of cuda for Read the Docs. + +#ifndef INCLUDE_GUARD_STUB_CUPY_COMPLEX_H +#define INCLUDE_GUARD_STUB_CUPY_COMPLEX_H + +extern "C" { + +/////////////////////////////////////////////////////////////////////////////// +// cuComplex.h +/////////////////////////////////////////////////////////////////////////////// + +struct cuComplex{ + float x, y; +}; + +struct cuDoubleComplex{ + double x, y; +}; + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_STUB_CUPY_COMPLEX_H \ No newline at end of file diff --git a/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h b/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h new file mode 100644 index 00000000..d1831837 --- /dev/null +++ b/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h @@ -0,0 +1,21 @@ +// This file is a stub header file of cuda for Read the Docs. + +#ifndef INCLUDE_GUARD_STUB_CUPY_CUBLAS_H +#define INCLUDE_GUARD_STUB_CUPY_CUBLAS_H + +#include "cupy_cuda_common.h" + +extern "C" { + +cublasStatus_t cublasCsyrk(...) { + return CUBLAS_STATUS_SUCCESS; +} + +cublasStatus_t cublasZsyrk(...) { + return CUBLAS_STATUS_SUCCESS; +} + + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_STUB_CUPY_CUBLAS_H \ No newline at end of file diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 4478c1fb..34716749 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -11,6 +11,8 @@ import cupy as cp from cupy_backends.cuda.libs import cublas from cupy.cuda import device + + from cupyfix_backends import cublasfix except (ImportError, ImportWarning, ModuleNotFoundError): pass @@ -114,9 +116,9 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals elif dtype == 'd': func = cublas.dsyrk elif dtype == 'F': - func = cublas.cherk + func = cublasfix.cherk elif dtype == 'D': - func = cublas.zherk + func = cublasfix.zherk else: raise TypeError('invalid dtype') From f7ae1cf45305d514fcd000c7989d8060c9cdfa74 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 11:41:44 +0000 Subject: [PATCH 101/157] typo --- src/serinv/block_primitive/syherk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 34716749..3fcbf6a3 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -12,7 +12,7 @@ from cupy_backends.cuda.libs import cublas from cupy.cuda import device - from cupyfix_backends import cublasfix + from cupyfix_backends import cublas as cublasfix except (ImportError, ImportWarning, ModuleNotFoundError): pass From dc9b9c2ece7b2f08b62d8ac5da7c0dbee0a1126c Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 11:45:00 +0000 Subject: [PATCH 102/157] fixed path --- src/serinv/block_primitive/syherk.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 3fcbf6a3..1d7d6ef2 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,5 +1,7 @@ from serinv import _get_module_from_array +from cupyfix_backends.cuda.libs import cublas as cublasfix + import numpy as np from numpy.linalg import matmul @@ -12,7 +14,7 @@ from cupy_backends.cuda.libs import cublas from cupy.cuda import device - from cupyfix_backends import cublas as cublasfix + except (ImportError, ImportWarning, ModuleNotFoundError): pass From 59831e15f1f861966993ccae4df19d8f39f222f9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 11:47:19 +0000 Subject: [PATCH 103/157] created init files --- src/serinv/block_primitive/cupyfix_backends/__init__.pxd | 0 src/serinv/block_primitive/cupyfix_backends/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/serinv/block_primitive/cupyfix_backends/__init__.pxd create mode 100644 src/serinv/block_primitive/cupyfix_backends/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/__init__.pxd b/src/serinv/block_primitive/cupyfix_backends/__init__.pxd new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/block_primitive/cupyfix_backends/__init__.py b/src/serinv/block_primitive/cupyfix_backends/__init__.py new file mode 100644 index 00000000..e69de29b From 2c182baf8946bae44deba83a6ec5cc448f43563a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 11:51:02 +0000 Subject: [PATCH 104/157] more init files --- src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd | 0 src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py | 0 src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd | 0 src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd create mode 100644 src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py create mode 100644 src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd create mode 100644 src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd b/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py b/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py new file mode 100644 index 00000000..e69de29b From defea4d678d4993981b9edd7fe23fad196da063a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 11:55:21 +0000 Subject: [PATCH 105/157] moved cupy part to try --- src/serinv/block_primitive/syherk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 1d7d6ef2..05677ad8 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,6 +1,6 @@ from serinv import _get_module_from_array -from cupyfix_backends.cuda.libs import cublas as cublasfix + import numpy as np from numpy.linalg import matmul @@ -14,7 +14,7 @@ from cupy_backends.cuda.libs import cublas from cupy.cuda import device - + from cupyfix_backends.cuda.libs import cublas as cublasfix except (ImportError, ImportWarning, ModuleNotFoundError): pass From 025673ef97a49e60503ea12c404c32d90156aae4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:03:33 +0000 Subject: [PATCH 106/157] put import back --- src/serinv/block_primitive/syherk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 05677ad8..1d7d6ef2 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,6 +1,6 @@ from serinv import _get_module_from_array - +from cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np from numpy.linalg import matmul @@ -14,7 +14,7 @@ from cupy_backends.cuda.libs import cublas from cupy.cuda import device - from cupyfix_backends.cuda.libs import cublas as cublasfix + except (ImportError, ImportWarning, ModuleNotFoundError): pass From 0073cd0b653be0852947527cde57469940da0f5e Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:10:38 +0000 Subject: [PATCH 107/157] renamed cupyfix_backends to backends --- .../block_primitive/{cupyfix_backends => backends}/__init__.pxd | 0 .../block_primitive/{cupyfix_backends => backends}/__init__.py | 0 .../{cupyfix_backends => backends}/cuda/__init__.pxd | 0 .../{cupyfix_backends => backends}/cuda/__init__.py | 0 .../{cupyfix_backends => backends}/cuda/cupy_cublas.h | 0 .../{cupyfix_backends => backends}/cuda/libs/__init.pxd | 0 .../{cupyfix_backends => backends}/cuda/libs/__init__.py | 0 .../{cupyfix_backends => backends}/cuda/libs/cublas.pxd | 0 .../{cupyfix_backends => backends}/cuda/libs/cublas.pyx | 0 .../block_primitive/{cupyfix_backends => backends}/cupy_blas.h | 0 .../{cupyfix_backends => backends}/cupy_complex.h | 0 .../{cupyfix_backends => backends}/hip/cupy_cuComplex.h | 0 .../{cupyfix_backends => backends}/hip/cupy_hip_common.h | 0 .../{cupyfix_backends => backends}/hip/cupy_hipblas.h | 0 .../{cupyfix_backends => backends}/stub/cupy_cuComplex.h | 0 .../{cupyfix_backends => backends}/stub/cupy_cublas.h | 0 src/serinv/block_primitive/syherk.py | 2 +- 17 files changed, 1 insertion(+), 1 deletion(-) rename src/serinv/block_primitive/{cupyfix_backends => backends}/__init__.pxd (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/__init__.py (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cuda/__init__.pxd (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cuda/__init__.py (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cuda/cupy_cublas.h (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cuda/libs/__init.pxd (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cuda/libs/__init__.py (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cuda/libs/cublas.pxd (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cuda/libs/cublas.pyx (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cupy_blas.h (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/cupy_complex.h (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/hip/cupy_cuComplex.h (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/hip/cupy_hip_common.h (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/hip/cupy_hipblas.h (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/stub/cupy_cuComplex.h (100%) rename src/serinv/block_primitive/{cupyfix_backends => backends}/stub/cupy_cublas.h (100%) diff --git a/src/serinv/block_primitive/cupyfix_backends/__init__.pxd b/src/serinv/block_primitive/backends/__init__.pxd similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/__init__.pxd rename to src/serinv/block_primitive/backends/__init__.pxd diff --git a/src/serinv/block_primitive/cupyfix_backends/__init__.py b/src/serinv/block_primitive/backends/__init__.py similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/__init__.py rename to src/serinv/block_primitive/backends/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd b/src/serinv/block_primitive/backends/cuda/__init__.pxd similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd rename to src/serinv/block_primitive/backends/cuda/__init__.pxd diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py b/src/serinv/block_primitive/backends/cuda/__init__.py similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py rename to src/serinv/block_primitive/backends/cuda/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h b/src/serinv/block_primitive/backends/cuda/cupy_cublas.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h rename to src/serinv/block_primitive/backends/cuda/cupy_cublas.h diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd b/src/serinv/block_primitive/backends/cuda/libs/__init.pxd similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd rename to src/serinv/block_primitive/backends/cuda/libs/__init.pxd diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py b/src/serinv/block_primitive/backends/cuda/libs/__init__.py similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py rename to src/serinv/block_primitive/backends/cuda/libs/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd b/src/serinv/block_primitive/backends/cuda/libs/cublas.pxd similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd rename to src/serinv/block_primitive/backends/cuda/libs/cublas.pxd diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/block_primitive/backends/cuda/libs/cublas.pyx similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx rename to src/serinv/block_primitive/backends/cuda/libs/cublas.pyx diff --git a/src/serinv/block_primitive/cupyfix_backends/cupy_blas.h b/src/serinv/block_primitive/backends/cupy_blas.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cupy_blas.h rename to src/serinv/block_primitive/backends/cupy_blas.h diff --git a/src/serinv/block_primitive/cupyfix_backends/cupy_complex.h b/src/serinv/block_primitive/backends/cupy_complex.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cupy_complex.h rename to src/serinv/block_primitive/backends/cupy_complex.h diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h b/src/serinv/block_primitive/backends/hip/cupy_cuComplex.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h rename to src/serinv/block_primitive/backends/hip/cupy_cuComplex.h diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h b/src/serinv/block_primitive/backends/hip/cupy_hip_common.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h rename to src/serinv/block_primitive/backends/hip/cupy_hip_common.h diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h b/src/serinv/block_primitive/backends/hip/cupy_hipblas.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h rename to src/serinv/block_primitive/backends/hip/cupy_hipblas.h diff --git a/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h b/src/serinv/block_primitive/backends/stub/cupy_cuComplex.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h rename to src/serinv/block_primitive/backends/stub/cupy_cuComplex.h diff --git a/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h b/src/serinv/block_primitive/backends/stub/cupy_cublas.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h rename to src/serinv/block_primitive/backends/stub/cupy_cublas.h diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 1d7d6ef2..f38fcc4f 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,6 +1,6 @@ from serinv import _get_module_from_array -from cupyfix_backends.cuda.libs import cublas as cublasfix +from backends.cuda.libs import cublas as cublasfix import numpy as np from numpy.linalg import matmul From a20eaec26806733294c30e7b6eb58a46ad5f16a8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:11:37 +0000 Subject: [PATCH 108/157] reverted change because it didn't help --- .../block_primitive/{backends => cupyfix_backends}/__init__.pxd | 0 .../block_primitive/{backends => cupyfix_backends}/__init__.py | 0 .../{backends => cupyfix_backends}/cuda/__init__.pxd | 0 .../{backends => cupyfix_backends}/cuda/__init__.py | 0 .../{backends => cupyfix_backends}/cuda/cupy_cublas.h | 0 .../{backends => cupyfix_backends}/cuda/libs/__init.pxd | 0 .../{backends => cupyfix_backends}/cuda/libs/__init__.py | 0 .../{backends => cupyfix_backends}/cuda/libs/cublas.pxd | 0 .../{backends => cupyfix_backends}/cuda/libs/cublas.pyx | 0 .../block_primitive/{backends => cupyfix_backends}/cupy_blas.h | 0 .../{backends => cupyfix_backends}/cupy_complex.h | 0 .../{backends => cupyfix_backends}/hip/cupy_cuComplex.h | 0 .../{backends => cupyfix_backends}/hip/cupy_hip_common.h | 0 .../{backends => cupyfix_backends}/hip/cupy_hipblas.h | 0 .../{backends => cupyfix_backends}/stub/cupy_cuComplex.h | 0 .../{backends => cupyfix_backends}/stub/cupy_cublas.h | 0 src/serinv/block_primitive/syherk.py | 2 +- 17 files changed, 1 insertion(+), 1 deletion(-) rename src/serinv/block_primitive/{backends => cupyfix_backends}/__init__.pxd (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/__init__.py (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cuda/__init__.pxd (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cuda/__init__.py (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cuda/cupy_cublas.h (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cuda/libs/__init.pxd (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cuda/libs/__init__.py (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cuda/libs/cublas.pxd (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cuda/libs/cublas.pyx (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cupy_blas.h (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/cupy_complex.h (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/hip/cupy_cuComplex.h (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/hip/cupy_hip_common.h (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/hip/cupy_hipblas.h (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/stub/cupy_cuComplex.h (100%) rename src/serinv/block_primitive/{backends => cupyfix_backends}/stub/cupy_cublas.h (100%) diff --git a/src/serinv/block_primitive/backends/__init__.pxd b/src/serinv/block_primitive/cupyfix_backends/__init__.pxd similarity index 100% rename from src/serinv/block_primitive/backends/__init__.pxd rename to src/serinv/block_primitive/cupyfix_backends/__init__.pxd diff --git a/src/serinv/block_primitive/backends/__init__.py b/src/serinv/block_primitive/cupyfix_backends/__init__.py similarity index 100% rename from src/serinv/block_primitive/backends/__init__.py rename to src/serinv/block_primitive/cupyfix_backends/__init__.py diff --git a/src/serinv/block_primitive/backends/cuda/__init__.pxd b/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd similarity index 100% rename from src/serinv/block_primitive/backends/cuda/__init__.pxd rename to src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd diff --git a/src/serinv/block_primitive/backends/cuda/__init__.py b/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py similarity index 100% rename from src/serinv/block_primitive/backends/cuda/__init__.py rename to src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py diff --git a/src/serinv/block_primitive/backends/cuda/cupy_cublas.h b/src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h similarity index 100% rename from src/serinv/block_primitive/backends/cuda/cupy_cublas.h rename to src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h diff --git a/src/serinv/block_primitive/backends/cuda/libs/__init.pxd b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd similarity index 100% rename from src/serinv/block_primitive/backends/cuda/libs/__init.pxd rename to src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd diff --git a/src/serinv/block_primitive/backends/cuda/libs/__init__.py b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py similarity index 100% rename from src/serinv/block_primitive/backends/cuda/libs/__init__.py rename to src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py diff --git a/src/serinv/block_primitive/backends/cuda/libs/cublas.pxd b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd similarity index 100% rename from src/serinv/block_primitive/backends/cuda/libs/cublas.pxd rename to src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd diff --git a/src/serinv/block_primitive/backends/cuda/libs/cublas.pyx b/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx similarity index 100% rename from src/serinv/block_primitive/backends/cuda/libs/cublas.pyx rename to src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx diff --git a/src/serinv/block_primitive/backends/cupy_blas.h b/src/serinv/block_primitive/cupyfix_backends/cupy_blas.h similarity index 100% rename from src/serinv/block_primitive/backends/cupy_blas.h rename to src/serinv/block_primitive/cupyfix_backends/cupy_blas.h diff --git a/src/serinv/block_primitive/backends/cupy_complex.h b/src/serinv/block_primitive/cupyfix_backends/cupy_complex.h similarity index 100% rename from src/serinv/block_primitive/backends/cupy_complex.h rename to src/serinv/block_primitive/cupyfix_backends/cupy_complex.h diff --git a/src/serinv/block_primitive/backends/hip/cupy_cuComplex.h b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h similarity index 100% rename from src/serinv/block_primitive/backends/hip/cupy_cuComplex.h rename to src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h diff --git a/src/serinv/block_primitive/backends/hip/cupy_hip_common.h b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h similarity index 100% rename from src/serinv/block_primitive/backends/hip/cupy_hip_common.h rename to src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h diff --git a/src/serinv/block_primitive/backends/hip/cupy_hipblas.h b/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h similarity index 100% rename from src/serinv/block_primitive/backends/hip/cupy_hipblas.h rename to src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h diff --git a/src/serinv/block_primitive/backends/stub/cupy_cuComplex.h b/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h similarity index 100% rename from src/serinv/block_primitive/backends/stub/cupy_cuComplex.h rename to src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h diff --git a/src/serinv/block_primitive/backends/stub/cupy_cublas.h b/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h similarity index 100% rename from src/serinv/block_primitive/backends/stub/cupy_cublas.h rename to src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index f38fcc4f..1d7d6ef2 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,6 +1,6 @@ from serinv import _get_module_from_array -from backends.cuda.libs import cublas as cublasfix +from cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np from numpy.linalg import matmul From 5b4de18555b766d137df7b9fa2df56f388ae57f9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:18:05 +0000 Subject: [PATCH 109/157] debug import --- src/serinv/block_primitive/syherk.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 1d7d6ef2..d1be8233 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,5 +1,6 @@ from serinv import _get_module_from_array +import cupyfix_backends from cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np From 718ff3bf1710a8d2c111b152560b748512b87170 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:21:28 +0000 Subject: [PATCH 110/157] created new module to check if something is wrong --- src/serinv/block_primitive/cupyfix_backends/__init__.py | 6 ++++++ src/serinv/block_primitive/syherk.py | 2 +- src/serinv/block_primitive/trymod/__init__.py | 0 3 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 src/serinv/block_primitive/trymod/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/__init__.py b/src/serinv/block_primitive/cupyfix_backends/__init__.py index e69de29b..42455d73 100644 --- a/src/serinv/block_primitive/cupyfix_backends/__init__.py +++ b/src/serinv/block_primitive/cupyfix_backends/__init__.py @@ -0,0 +1,6 @@ +def foo(): + return 0 + +__all__ = [ + "foo", +] diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index d1be8233..1ca78594 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,6 +1,6 @@ from serinv import _get_module_from_array -import cupyfix_backends +import trymod from cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np diff --git a/src/serinv/block_primitive/trymod/__init__.py b/src/serinv/block_primitive/trymod/__init__.py new file mode 100644 index 00000000..e69de29b From 25611ef016fbc5f1a2cc3a83a6123df3acf51cec Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:22:24 +0000 Subject: [PATCH 111/157] updated test module --- src/serinv/block_primitive/cupyfix_backends/__init__.py | 5 ----- src/serinv/block_primitive/trymod/__init__.py | 6 ++++++ 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/serinv/block_primitive/cupyfix_backends/__init__.py b/src/serinv/block_primitive/cupyfix_backends/__init__.py index 42455d73..8b137891 100644 --- a/src/serinv/block_primitive/cupyfix_backends/__init__.py +++ b/src/serinv/block_primitive/cupyfix_backends/__init__.py @@ -1,6 +1 @@ -def foo(): - return 0 -__all__ = [ - "foo", -] diff --git a/src/serinv/block_primitive/trymod/__init__.py b/src/serinv/block_primitive/trymod/__init__.py index e69de29b..d167fa53 100644 --- a/src/serinv/block_primitive/trymod/__init__.py +++ b/src/serinv/block_primitive/trymod/__init__.py @@ -0,0 +1,6 @@ +def foo(): + return 0 + +__all__ = [ + "foo", +] \ No newline at end of file From 5a8f26828630c7541de950b4d2a5d22d6eb79ecf Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:28:33 +0000 Subject: [PATCH 112/157] attempt to fix import problem --- src/serinv/block_primitive/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py index 89b14a73..941b6c92 100644 --- a/src/serinv/block_primitive/__init__.py +++ b/src/serinv/block_primitive/__init__.py @@ -1,9 +1,11 @@ from serinv.block_primitive.gemm import gemm from serinv.block_primitive.trsm import trsm from serinv.block_primitive.syherk import syherk +import cupyfix_backends __all__ = [ "gemm", "trsm", - "syherk" + "syherk", + cupyfix_backends ] \ No newline at end of file From 01b53bcaeb5215c63086177d46e908ed595a5e21 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:28:57 +0000 Subject: [PATCH 113/157] debugging --- src/serinv/block_primitive/syherk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 1ca78594..48030b40 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,6 +1,6 @@ from serinv import _get_module_from_array -import trymod +#import trymod from cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np From bd0e04316bb4226d25db1cd2d3c62b692dae50fb Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:30:32 +0000 Subject: [PATCH 114/157] further debugging --- src/serinv/block_primitive/syherk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 48030b40..08716d6d 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,7 +1,7 @@ from serinv import _get_module_from_array #import trymod -from cupyfix_backends.cuda.libs import cublas as cublasfix +from block_primitive.cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np from numpy.linalg import matmul From 8334fdb5e7be53babfc546de153723ba2037dd88 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:31:04 +0000 Subject: [PATCH 115/157] more debug --- src/serinv/block_primitive/syherk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 08716d6d..db6bed00 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,7 +1,7 @@ from serinv import _get_module_from_array #import trymod -from block_primitive.cupyfix_backends.cuda.libs import cublas as cublasfix +from serinv.block_primitive.cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np from numpy.linalg import matmul From 32b61bb76c4d9ad3d8e727be32a9cc8573c196ff Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 12:37:56 +0000 Subject: [PATCH 116/157] try to import test module --- src/serinv/block_primitive/syherk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index db6bed00..9bd6f8ef 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,6 +1,6 @@ from serinv import _get_module_from_array -#import trymod +import serinv.block_primitive.trymod from serinv.block_primitive.cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np From 2c4d3003e65af9e0792ff83485906986db128b20 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 13:36:40 +0000 Subject: [PATCH 117/157] attempt at setup.py --- setup.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 setup.py diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..669754e6 --- /dev/null +++ b/setup.py @@ -0,0 +1,22 @@ +from setuptools import setup, Extension +from Cython.Build import cythonize + +ext = Extension( + name="cupyfix_backends.cuda.libs.cublas", + sources=["cupyfix_backends/cuda/libs/cublas.pyx", + "cupyfix_backends/cuda/cupy_cublas.h", + "cupyfix_backends/cuda/hip/cupy_cuComplex.h", + "cupyfix_backends/cuda/hip/cupy_hip_common.h", + "cupyfix_backends/cuda/hip/cupy_hipblas.h", + "cupyfix_backends/cuda/stub/cupy_cublas.h", + "cupyfix_backends/cuda/stub/cupy_cuComplex.h", + "cupyfix_backends/cuda/cupy_blas.h" + "cupyfix_backends/cuda/cupy_complex.h"], + include_dirs=["cupyfix_backends"], +) + +setup( + name="cupyfix_backends", + ext_modules=cythonize([ext]), + packages=["cupyfix_backend"], +) \ No newline at end of file From eb711442cee260fa2b85b4711eaea859ad4f69bf Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 13:39:40 +0000 Subject: [PATCH 118/157] changed c in cython --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 669754e6..a9d79ae9 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ from setuptools import setup, Extension -from Cython.Build import cythonize +from cython.Build import cythonize ext = Extension( name="cupyfix_backends.cuda.libs.cublas", From 7182951c43b6132757477b0303bdf9e28ff1a4c5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 13:41:46 +0000 Subject: [PATCH 119/157] reverted to big C --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a9d79ae9..669754e6 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ from setuptools import setup, Extension -from cython.Build import cythonize +from Cython.Build import cythonize ext = Extension( name="cupyfix_backends.cuda.libs.cublas", From 884fb3a70d30fcd1612a0e9125202d522bf30615 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 13:46:18 +0000 Subject: [PATCH 120/157] moved backend folder --- src/serinv/{block_primitive => }/cupyfix_backends/__init__.pxd | 0 src/serinv/{block_primitive => }/cupyfix_backends/__init__.py | 0 .../{block_primitive => }/cupyfix_backends/cuda/__init__.pxd | 0 .../{block_primitive => }/cupyfix_backends/cuda/__init__.py | 0 .../{block_primitive => }/cupyfix_backends/cuda/cupy_cublas.h | 0 .../{block_primitive => }/cupyfix_backends/cuda/libs/__init.pxd | 0 .../{block_primitive => }/cupyfix_backends/cuda/libs/__init__.py | 0 .../{block_primitive => }/cupyfix_backends/cuda/libs/cublas.pxd | 0 .../{block_primitive => }/cupyfix_backends/cuda/libs/cublas.pyx | 0 src/serinv/{block_primitive => }/cupyfix_backends/cupy_blas.h | 0 src/serinv/{block_primitive => }/cupyfix_backends/cupy_complex.h | 0 .../{block_primitive => }/cupyfix_backends/hip/cupy_cuComplex.h | 0 .../{block_primitive => }/cupyfix_backends/hip/cupy_hip_common.h | 0 .../{block_primitive => }/cupyfix_backends/hip/cupy_hipblas.h | 0 .../{block_primitive => }/cupyfix_backends/stub/cupy_cuComplex.h | 0 .../{block_primitive => }/cupyfix_backends/stub/cupy_cublas.h | 0 16 files changed, 0 insertions(+), 0 deletions(-) rename src/serinv/{block_primitive => }/cupyfix_backends/__init__.pxd (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/__init__.py (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cuda/__init__.pxd (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cuda/__init__.py (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cuda/cupy_cublas.h (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cuda/libs/__init.pxd (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cuda/libs/__init__.py (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cuda/libs/cublas.pxd (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cuda/libs/cublas.pyx (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cupy_blas.h (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/cupy_complex.h (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/hip/cupy_cuComplex.h (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/hip/cupy_hip_common.h (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/hip/cupy_hipblas.h (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/stub/cupy_cuComplex.h (100%) rename src/serinv/{block_primitive => }/cupyfix_backends/stub/cupy_cublas.h (100%) diff --git a/src/serinv/block_primitive/cupyfix_backends/__init__.pxd b/src/serinv/cupyfix_backends/__init__.pxd similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/__init__.pxd rename to src/serinv/cupyfix_backends/__init__.pxd diff --git a/src/serinv/block_primitive/cupyfix_backends/__init__.py b/src/serinv/cupyfix_backends/__init__.py similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/__init__.py rename to src/serinv/cupyfix_backends/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd b/src/serinv/cupyfix_backends/cuda/__init__.pxd similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/__init__.pxd rename to src/serinv/cupyfix_backends/cuda/__init__.pxd diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py b/src/serinv/cupyfix_backends/cuda/__init__.py similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/__init__.py rename to src/serinv/cupyfix_backends/cuda/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h b/src/serinv/cupyfix_backends/cuda/cupy_cublas.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/cupy_cublas.h rename to src/serinv/cupyfix_backends/cuda/cupy_cublas.h diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd b/src/serinv/cupyfix_backends/cuda/libs/__init.pxd similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init.pxd rename to src/serinv/cupyfix_backends/cuda/libs/__init.pxd diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py b/src/serinv/cupyfix_backends/cuda/libs/__init__.py similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/libs/__init__.py rename to src/serinv/cupyfix_backends/cuda/libs/__init__.py diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd b/src/serinv/cupyfix_backends/cuda/libs/cublas.pxd similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pxd rename to src/serinv/cupyfix_backends/cuda/libs/cublas.pxd diff --git a/src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cuda/libs/cublas.pyx rename to src/serinv/cupyfix_backends/cuda/libs/cublas.pyx diff --git a/src/serinv/block_primitive/cupyfix_backends/cupy_blas.h b/src/serinv/cupyfix_backends/cupy_blas.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cupy_blas.h rename to src/serinv/cupyfix_backends/cupy_blas.h diff --git a/src/serinv/block_primitive/cupyfix_backends/cupy_complex.h b/src/serinv/cupyfix_backends/cupy_complex.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/cupy_complex.h rename to src/serinv/cupyfix_backends/cupy_complex.h diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h b/src/serinv/cupyfix_backends/hip/cupy_cuComplex.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/hip/cupy_cuComplex.h rename to src/serinv/cupyfix_backends/hip/cupy_cuComplex.h diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h b/src/serinv/cupyfix_backends/hip/cupy_hip_common.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/hip/cupy_hip_common.h rename to src/serinv/cupyfix_backends/hip/cupy_hip_common.h diff --git a/src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h b/src/serinv/cupyfix_backends/hip/cupy_hipblas.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/hip/cupy_hipblas.h rename to src/serinv/cupyfix_backends/hip/cupy_hipblas.h diff --git a/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h b/src/serinv/cupyfix_backends/stub/cupy_cuComplex.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/stub/cupy_cuComplex.h rename to src/serinv/cupyfix_backends/stub/cupy_cuComplex.h diff --git a/src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h b/src/serinv/cupyfix_backends/stub/cupy_cublas.h similarity index 100% rename from src/serinv/block_primitive/cupyfix_backends/stub/cupy_cublas.h rename to src/serinv/cupyfix_backends/stub/cupy_cublas.h From 6eef7f63418ec0f4916f39d5f644ab1bc79e2314 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 13:47:06 +0000 Subject: [PATCH 121/157] attempt to fix install --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 669754e6..574d7671 100644 --- a/setup.py +++ b/setup.py @@ -3,15 +3,15 @@ ext = Extension( name="cupyfix_backends.cuda.libs.cublas", - sources=["cupyfix_backends/cuda/libs/cublas.pyx", - "cupyfix_backends/cuda/cupy_cublas.h", - "cupyfix_backends/cuda/hip/cupy_cuComplex.h", - "cupyfix_backends/cuda/hip/cupy_hip_common.h", - "cupyfix_backends/cuda/hip/cupy_hipblas.h", - "cupyfix_backends/cuda/stub/cupy_cublas.h", - "cupyfix_backends/cuda/stub/cupy_cuComplex.h", - "cupyfix_backends/cuda/cupy_blas.h" - "cupyfix_backends/cuda/cupy_complex.h"], + sources=["~/cupyfix_backends/cuda/libs/cublas.pyx", + "~/cupyfix_backends/cuda/cupy_cublas.h", + "~/cupyfix_backends/cuda/hip/cupy_cuComplex.h", + "~/cupyfix_backends/cuda/hip/cupy_hip_common.h", + "~/cupyfix_backends/cuda/hip/cupy_hipblas.h", + "~/cupyfix_backends/cuda/stub/cupy_cublas.h", + "~/cupyfix_backends/cuda/stub/cupy_cuComplex.h", + "~/cupyfix_backends/cuda/cupy_blas.h" + "~/cupyfix_backends/cuda/cupy_complex.h"], include_dirs=["cupyfix_backends"], ) From 9b677f287a0426806e6e0853a0f3127d1d31312e Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 13:48:08 +0000 Subject: [PATCH 122/157] fixing file path --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 574d7671..3706f55e 100644 --- a/setup.py +++ b/setup.py @@ -3,15 +3,15 @@ ext = Extension( name="cupyfix_backends.cuda.libs.cublas", - sources=["~/cupyfix_backends/cuda/libs/cublas.pyx", - "~/cupyfix_backends/cuda/cupy_cublas.h", - "~/cupyfix_backends/cuda/hip/cupy_cuComplex.h", - "~/cupyfix_backends/cuda/hip/cupy_hip_common.h", - "~/cupyfix_backends/cuda/hip/cupy_hipblas.h", - "~/cupyfix_backends/cuda/stub/cupy_cublas.h", - "~/cupyfix_backends/cuda/stub/cupy_cuComplex.h", - "~/cupyfix_backends/cuda/cupy_blas.h" - "~/cupyfix_backends/cuda/cupy_complex.h"], + sources=["serinv/cupyfix_backends/cuda/libs/cublas.pyx", + "serinv/cupyfix_backends/cuda/cupy_cublas.h", + "serinv/cupyfix_backends/cuda/hip/cupy_cuComplex.h", + "serinv/cupyfix_backends/cuda/hip/cupy_hip_common.h", + "serinv/cupyfix_backends/cuda/hip/cupy_hipblas.h", + "serinv/cupyfix_backends/cuda/stub/cupy_cublas.h", + "serinv/cupyfix_backends/cuda/stub/cupy_cuComplex.h", + "serinv/cupyfix_backends/cuda/cupy_blas.h" + "serinv/cupyfix_backends/cuda/cupy_complex.h"], include_dirs=["cupyfix_backends"], ) From 1047aa2048ccdab03f1b579117790fe3e1ec65e8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 13:50:56 +0000 Subject: [PATCH 123/157] more path fixing --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 3706f55e..45dbfcfd 100644 --- a/setup.py +++ b/setup.py @@ -3,15 +3,15 @@ ext = Extension( name="cupyfix_backends.cuda.libs.cublas", - sources=["serinv/cupyfix_backends/cuda/libs/cublas.pyx", - "serinv/cupyfix_backends/cuda/cupy_cublas.h", - "serinv/cupyfix_backends/cuda/hip/cupy_cuComplex.h", - "serinv/cupyfix_backends/cuda/hip/cupy_hip_common.h", - "serinv/cupyfix_backends/cuda/hip/cupy_hipblas.h", - "serinv/cupyfix_backends/cuda/stub/cupy_cublas.h", - "serinv/cupyfix_backends/cuda/stub/cupy_cuComplex.h", - "serinv/cupyfix_backends/cuda/cupy_blas.h" - "serinv/cupyfix_backends/cuda/cupy_complex.h"], + sources=["src/serinv/cupyfix_backends/cuda/libs/cublas.pyx", + "src/serinv/cupyfix_backends/cuda/cupy_cublas.h", + "src/serinv/cupyfix_backends/cuda/hip/cupy_cuComplex.h", + "src/serinv/cupyfix_backends/cuda/hip/cupy_hip_common.h", + "src/serinv/cupyfix_backends/cuda/hip/cupy_hipblas.h", + "src/serinv/cupyfix_backends/cuda/stub/cupy_cublas.h", + "src/serinv/cupyfix_backends/cuda/stub/cupy_cuComplex.h", + "src/serinv/cupyfix_backends/cuda/cupy_blas.h" + "src/serinv/cupyfix_backends/cuda/cupy_complex.h"], include_dirs=["cupyfix_backends"], ) From aca1e0d6a1428d35626f40b6241224bc8629c15c Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 13:59:29 +0000 Subject: [PATCH 124/157] changed imports --- src/serinv/cupyfix_backends/cuda/api/__init.pxd | 0 src/serinv/cupyfix_backends/cuda/api/__init__.py | 0 src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 src/serinv/cupyfix_backends/cuda/api/__init.pxd create mode 100644 src/serinv/cupyfix_backends/cuda/api/__init__.py diff --git a/src/serinv/cupyfix_backends/cuda/api/__init.pxd b/src/serinv/cupyfix_backends/cuda/api/__init.pxd new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/cupyfix_backends/cuda/api/__init__.py b/src/serinv/cupyfix_backends/cuda/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index f406b3a8..1782238f 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -4,8 +4,8 @@ cimport cython # NOQA -from cupy_backends.cuda.api cimport runtime -from cupy_backends.cuda cimport stream as stream_module +from cupy_backends.cuda.api import runtime +from cupy_backends.cuda import stream as stream_module ############################################################################### # Extern From bf6d0f354c4271fc83e6892855149b4023f05753 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 14:01:44 +0000 Subject: [PATCH 125/157] included missing file --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 45dbfcfd..17766f5f 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ ext = Extension( name="cupyfix_backends.cuda.libs.cublas", sources=["src/serinv/cupyfix_backends/cuda/libs/cublas.pyx", + "src/serinv/cupyfix_backends/cuda/libs/cublas.pxd", "src/serinv/cupyfix_backends/cuda/cupy_cublas.h", "src/serinv/cupyfix_backends/cuda/hip/cupy_cuComplex.h", "src/serinv/cupyfix_backends/cuda/hip/cupy_hip_common.h", From eadbd480676ebcb7d656b2755545d3e9110c5e71 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 14:03:15 +0000 Subject: [PATCH 126/157] changed source order --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 17766f5f..a3b30769 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,8 @@ ext = Extension( name="cupyfix_backends.cuda.libs.cublas", - sources=["src/serinv/cupyfix_backends/cuda/libs/cublas.pyx", - "src/serinv/cupyfix_backends/cuda/libs/cublas.pxd", + sources=["src/serinv/cupyfix_backends/cuda/libs/cublas.pxd", + "src/serinv/cupyfix_backends/cuda/libs/cublas.pyx", "src/serinv/cupyfix_backends/cuda/cupy_cublas.h", "src/serinv/cupyfix_backends/cuda/hip/cupy_cuComplex.h", "src/serinv/cupyfix_backends/cuda/hip/cupy_hip_common.h", From 525dd1244c3ce18da9b29a421d7a856965ce1628 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 14:08:30 +0000 Subject: [PATCH 127/157] added context declarations --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index 1782238f..c312490e 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -22,6 +22,21 @@ cdef extern from '../../cupy_blas.h' nogil: ctypedef void* Stream 'cudaStream_t' ctypedef int DataType 'cudaDataType' + # Context + int cublasCreate(Handle* handle) + int cublasDestroy(Handle handle) + int cublasGetVersion(Handle handle, int* version) + int cublasGetPointerMode(Handle handle, PointerMode* mode) + int cublasSetPointerMode(Handle handle, PointerMode mode) + + # Stream + int cublasSetStream(Handle handle, Stream streamId) + int cublasGetStream(Handle handle, Stream* streamId) + + # Math Mode + int cublasSetMathMode(Handle handle, Math mode) + int cublasGetMathMode(Handle handle, Math* mode) + # BLAS Level 3 int cublasCherk( Handle handle, FillMode uplo, Operation trans, int n, int k, From fdd63616f099f0013f030a7363a42c6d4f537db5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 14:10:58 +0000 Subject: [PATCH 128/157] cimporting cublas --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index c312490e..c9c90f44 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -4,6 +4,8 @@ cimport cython # NOQA +cimport cublas + from cupy_backends.cuda.api import runtime from cupy_backends.cuda import stream as stream_module From 13aee99331859c43048a5d91de2352437ec8a2bc Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 14:13:17 +0000 Subject: [PATCH 129/157] removed context --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index c9c90f44..6bca48ec 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -24,21 +24,6 @@ cdef extern from '../../cupy_blas.h' nogil: ctypedef void* Stream 'cudaStream_t' ctypedef int DataType 'cudaDataType' - # Context - int cublasCreate(Handle* handle) - int cublasDestroy(Handle handle) - int cublasGetVersion(Handle handle, int* version) - int cublasGetPointerMode(Handle handle, PointerMode* mode) - int cublasSetPointerMode(Handle handle, PointerMode mode) - - # Stream - int cublasSetStream(Handle handle, Stream streamId) - int cublasGetStream(Handle handle, Stream* streamId) - - # Math Mode - int cublasSetMathMode(Handle handle, Math mode) - int cublasGetMathMode(Handle handle, Math* mode) - # BLAS Level 3 int cublasCherk( Handle handle, FillMode uplo, Operation trans, int n, int k, From e605747801cd79e10606f0067c5b94beb1dd3b18 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 14:16:20 +0000 Subject: [PATCH 130/157] try to correctly import cublas.pxd --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index 6bca48ec..74bb2a23 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -4,7 +4,7 @@ cimport cython # NOQA -cimport cublas +from cupyfix_backends.cuda.libs cimport cublas from cupy_backends.cuda.api import runtime from cupy_backends.cuda import stream as stream_module From 2918d17914c3e5a85f212862294ad9ec6d6294f6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 14:17:15 +0000 Subject: [PATCH 131/157] import fix --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index 74bb2a23..b6f33d4a 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -4,7 +4,7 @@ cimport cython # NOQA -from cupyfix_backends.cuda.libs cimport cublas +cimport cupyfix_backends.cuda.libs.cublas as cublas from cupy_backends.cuda.api import runtime from cupy_backends.cuda import stream as stream_module From d497ffcec0be770954034026322467c3d0ed8a33 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 15:02:03 +0000 Subject: [PATCH 132/157] moving the typedef --- src/serinv/cupyfix_backends/cuda/api/__init.pxd | 0 src/serinv/cupyfix_backends/cuda/api/__init__.py | 0 src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 14 +++++++++++++- 3 files changed, 13 insertions(+), 1 deletion(-) delete mode 100644 src/serinv/cupyfix_backends/cuda/api/__init.pxd delete mode 100644 src/serinv/cupyfix_backends/cuda/api/__init__.py diff --git a/src/serinv/cupyfix_backends/cuda/api/__init.pxd b/src/serinv/cupyfix_backends/cuda/api/__init.pxd deleted file mode 100644 index e69de29b..00000000 diff --git a/src/serinv/cupyfix_backends/cuda/api/__init__.py b/src/serinv/cupyfix_backends/cuda/api/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index b6f33d4a..fc1cfc1d 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -4,11 +4,23 @@ cimport cython # NOQA -cimport cupyfix_backends.cuda.libs.cublas as cublas + from cupy_backends.cuda.api import runtime from cupy_backends.cuda import stream as stream_module +cdef: + ctypedef void* Handle 'cublasHandle_t' + + ctypedef int DiagType 'cublasDiagType_t' + ctypedef int FillMode 'cublasFillMode_t' + ctypedef int Operation 'cublasOperation_t' + ctypedef int PointerMode 'cublasPointerMode_t' + ctypedef int SideMode 'cublasSideMode_t' + ctypedef int GemmAlgo 'cublasGemmAlgo_t' + ctypedef int Math 'cublasMath_t' + ctypedef int ComputeType 'cublasComputeType_t' + ############################################################################### # Extern ############################################################################### From c4054f0632d7a7c1eaf6307145b2aab4f343c92c Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 15:03:53 +0000 Subject: [PATCH 133/157] import intptr --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index fc1cfc1d..bc31e030 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -4,7 +4,7 @@ cimport cython # NOQA - +from libc.stdint cimport intptr_t from cupy_backends.cuda.api import runtime from cupy_backends.cuda import stream as stream_module From 30b61f670ac46d2be7ff9317611658f3cc7b933a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 15:15:35 +0000 Subject: [PATCH 134/157] added setstream --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index bc31e030..faaf35b8 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -36,6 +36,10 @@ cdef extern from '../../cupy_blas.h' nogil: ctypedef void* Stream 'cudaStream_t' ctypedef int DataType 'cudaDataType' + # Stream + int cublasSetStream(Handle handle, Stream streamId) + int cublasGetStream(Handle handle, Stream* streamId) + # BLAS Level 3 int cublasCherk( Handle handle, FillMode uplo, Operation trans, int n, int k, From 1e886c3c1ea8dd91bc010cbfc731613461bab99e Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 15:16:31 +0000 Subject: [PATCH 135/157] added setstrea, better --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index faaf35b8..445425d7 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -103,6 +103,21 @@ cpdef inline check_status(int status): raise CUBLASError(status) +cpdef setStream(intptr_t handle, size_t stream): + # TODO(leofang): It seems most of cuBLAS APIs support stream capture (as of + # CUDA 11.5) under certain conditions, see + # https://docs.nvidia.com/cuda/cublas/index.html#CUDA-graphs + # Before we come up with a robust strategy to test the support conditions, + # we disable this functionality. + if not runtime._is_hip_environment and runtime.streamIsCapturing(stream): + raise NotImplementedError( + 'calling cuBLAS API during stream capture is currently ' + 'unsupported') + + with nogil: + status = cublasSetStream(handle, stream) + check_status(status) + ############################################################################### # BLAS Level 3 From 124c41f1559a84edb0e5126b4ced6994a42fac1a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 15:16:59 +0000 Subject: [PATCH 136/157] added _setstream --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index 445425d7..704ea943 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -118,6 +118,9 @@ cpdef setStream(intptr_t handle, size_t stream): status = cublasSetStream(handle, stream) check_status(status) +cdef _setStream(intptr_t handle): + """Set current stream""" + setStream(handle, stream_module.get_current_stream_ptr()) ############################################################################### # BLAS Level 3 From dc99cbde268c13ece49c8c3742c1d14da2bca92d Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 15:18:08 +0000 Subject: [PATCH 137/157] changed path --- src/serinv/block_primitive/syherk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 9bd6f8ef..0ad0fdcb 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,7 +1,7 @@ from serinv import _get_module_from_array import serinv.block_primitive.trymod -from serinv.block_primitive.cupyfix_backends.cuda.libs import cublas as cublasfix +from serinv.cupyfix_backends.cuda.libs import cublas as cublasfix import numpy as np from numpy.linalg import matmul From ca620c9e722be42b11ff41aa83c70dfcdd6d73cd Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:21:54 +0000 Subject: [PATCH 138/157] renamed package --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index a3b30769..60d2d125 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ ) setup( - name="cupyfix_backends", + name="cublas", ext_modules=cythonize([ext]), - packages=["cupyfix_backend"], + packages=["cublas"], ) \ No newline at end of file From 0ff2ddc23a3ef75752c7df6d1d51a362091597f9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:22:54 +0000 Subject: [PATCH 139/157] renamedpackage again --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 60d2d125..be011ab2 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ ) setup( - name="cublas", + name="cuupyfix_backends", ext_modules=cythonize([ext]), packages=["cublas"], ) \ No newline at end of file From 5ac8805a72a6e06883d3bd5212aea8a37fa20bc8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:23:26 +0000 Subject: [PATCH 140/157] hopefully final fix for name --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index be011ab2..e1c4ffd8 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ ) setup( - name="cuupyfix_backends", + name="cupyfix_backends", ext_modules=cythonize([ext]), - packages=["cublas"], + packages=["cupyfix_backends"], ) \ No newline at end of file From a9b771b49c39c39748e1de50fb099c248c5d2e1b Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:25:04 +0000 Subject: [PATCH 141/157] further path fix --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e1c4ffd8..fd44ba8c 100644 --- a/setup.py +++ b/setup.py @@ -19,5 +19,5 @@ setup( name="cupyfix_backends", ext_modules=cythonize([ext]), - packages=["cupyfix_backends"], + packages=["src/serinv/cupyfix_backends"], ) \ No newline at end of file From cff36424a6d492c412b61e742e2f6b4c07fc1c8d Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:27:11 +0000 Subject: [PATCH 142/157] further renaming --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fd44ba8c..6f4d73e3 100644 --- a/setup.py +++ b/setup.py @@ -19,5 +19,5 @@ setup( name="cupyfix_backends", ext_modules=cythonize([ext]), - packages=["src/serinv/cupyfix_backends"], + packages=["src/serinv/cupyfix_backends.cuda.libs.cublas"], ) \ No newline at end of file From cd6f2a9c811d89d3b365c6e75862ffaeac6ec52f Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:28:41 +0000 Subject: [PATCH 143/157] continuous renaming --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6f4d73e3..77aec0fc 100644 --- a/setup.py +++ b/setup.py @@ -19,5 +19,5 @@ setup( name="cupyfix_backends", ext_modules=cythonize([ext]), - packages=["src/serinv/cupyfix_backends.cuda.libs.cublas"], + packages=["src/serinv/cupyfix_backends.cuda.libs"], ) \ No newline at end of file From 35c5c07ea021abc83fb133a3d3bb97c95b23e432 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:34:58 +0000 Subject: [PATCH 144/157] changed loaction of.h files --- setup.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 77aec0fc..58357c25 100644 --- a/setup.py +++ b/setup.py @@ -3,17 +3,18 @@ ext = Extension( name="cupyfix_backends.cuda.libs.cublas", - sources=["src/serinv/cupyfix_backends/cuda/libs/cublas.pxd", - "src/serinv/cupyfix_backends/cuda/libs/cublas.pyx", - "src/serinv/cupyfix_backends/cuda/cupy_cublas.h", - "src/serinv/cupyfix_backends/cuda/hip/cupy_cuComplex.h", - "src/serinv/cupyfix_backends/cuda/hip/cupy_hip_common.h", - "src/serinv/cupyfix_backends/cuda/hip/cupy_hipblas.h", - "src/serinv/cupyfix_backends/cuda/stub/cupy_cublas.h", - "src/serinv/cupyfix_backends/cuda/stub/cupy_cuComplex.h", - "src/serinv/cupyfix_backends/cuda/cupy_blas.h" - "src/serinv/cupyfix_backends/cuda/cupy_complex.h"], - include_dirs=["cupyfix_backends"], + sources=[ + "src/serinv/cupyfix_backends/cuda/libs/cublas.pyx"], + include_dirs=["src/serinv/cupyfix_backends/cuda/libs/cublas.pxd", + "src/serinv/cupyfix_backends/cuda/cupy_cublas.h", + "src/serinv/cupyfix_backends/cuda/hip/cupy_cuComplex.h", + "src/serinv/cupyfix_backends/cuda/hip/cupy_hip_common.h", + "src/serinv/cupyfix_backends/cuda/hip/cupy_hipblas.h", + "src/serinv/cupyfix_backends/cuda/stub/cupy_cublas.h", + "src/serinv/cupyfix_backends/cuda/stub/cupy_cuComplex.h", + "src/serinv/cupyfix_backends/cuda/cupy_blas.h", + "src/serinv/cupyfix_backends/cuda/cupy_complex.h" + ], ) setup( From 820e1b60730ccf46dcba52581d1228dd00d80e01 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:39:10 +0000 Subject: [PATCH 145/157] changed dirs --- setup.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 58357c25..70ea207c 100644 --- a/setup.py +++ b/setup.py @@ -5,16 +5,11 @@ name="cupyfix_backends.cuda.libs.cublas", sources=[ "src/serinv/cupyfix_backends/cuda/libs/cublas.pyx"], - include_dirs=["src/serinv/cupyfix_backends/cuda/libs/cublas.pxd", - "src/serinv/cupyfix_backends/cuda/cupy_cublas.h", - "src/serinv/cupyfix_backends/cuda/hip/cupy_cuComplex.h", - "src/serinv/cupyfix_backends/cuda/hip/cupy_hip_common.h", - "src/serinv/cupyfix_backends/cuda/hip/cupy_hipblas.h", - "src/serinv/cupyfix_backends/cuda/stub/cupy_cublas.h", - "src/serinv/cupyfix_backends/cuda/stub/cupy_cuComplex.h", - "src/serinv/cupyfix_backends/cuda/cupy_blas.h", - "src/serinv/cupyfix_backends/cuda/cupy_complex.h" - ], + include_dirs=["cupyfix_backends/cuda/libs", + "cupyfix_backends/hip", + "cupyfix_backends/stub", + "cupyfix_backends/cuda", + "cupyfix_backends"], ) setup( From 811517789fce38678883331a32df870f1e4f5396 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:45:22 +0000 Subject: [PATCH 146/157] bypass include conditions --- src/serinv/cupyfix_backends/cupy_complex.h | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/serinv/cupyfix_backends/cupy_complex.h b/src/serinv/cupyfix_backends/cupy_complex.h index 5c7efed9..e4ea29e5 100644 --- a/src/serinv/cupyfix_backends/cupy_complex.h +++ b/src/serinv/cupyfix_backends/cupy_complex.h @@ -1,17 +1,7 @@ #ifndef INCLUDE_GUARD_CUPY_COMPLEX_H #define INCLUDE_GUARD_CUPY_COMPLEX_H -#ifdef CUPY_USE_HIP -#include "hip/cupy_cuComplex.h" - -#elif !defined(CUPY_NO_CUDA) - -#include -#else // #if !defined(CUPY_NO_CUDA) || !defined(CUPY_USE_HIP) - -#include "stub/cupy_cuComplex.h" +#include "hip/cupy_cuComplex.h" -#endif // #ifndef CUPY_NO_CUDA -#endif // #ifndef INCLUDE_GUARD_CUPY_COMPLEX_H \ No newline at end of file From 04610cadc1125aae90318ddf2686eab7f12a4810 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:45:59 +0000 Subject: [PATCH 147/157] fixed missing endif --- src/serinv/cupyfix_backends/cupy_complex.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/cupyfix_backends/cupy_complex.h b/src/serinv/cupyfix_backends/cupy_complex.h index e4ea29e5..90c77337 100644 --- a/src/serinv/cupyfix_backends/cupy_complex.h +++ b/src/serinv/cupyfix_backends/cupy_complex.h @@ -5,3 +5,5 @@ #include "hip/cupy_cuComplex.h" + +#endif // #ifndef INCLUDE_GUARD_CUPY_COMPLEX_H \ No newline at end of file From b3119adbe331bfb104c2333e428a68d18508d71a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 22:49:10 +0000 Subject: [PATCH 148/157] tmore workarounds for possibly unneccessary includes --- .../cupyfix_backends/cuda/cupy_cublas.h | 24 ------------------- src/serinv/cupyfix_backends/cupy_blas.h | 10 -------- 2 files changed, 34 deletions(-) delete mode 100644 src/serinv/cupyfix_backends/cuda/cupy_cublas.h diff --git a/src/serinv/cupyfix_backends/cuda/cupy_cublas.h b/src/serinv/cupyfix_backends/cuda/cupy_cublas.h deleted file mode 100644 index c3d4874b..00000000 --- a/src/serinv/cupyfix_backends/cuda/cupy_cublas.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H -#define INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H - -#include -#include - -#if CUDA_VERSION >= 11000 - -#define cublasGemmEx_v11 cublasGemmEx -#define cublasGemmStridedBatchedEx_v11 cublasGemmStridedBatchedEx - -#else - -typedef enum{} cublasComputeType_t; -cublasStatus_t cublasGemmEx_v11(...) { - return CUBLAS_STATUS_NOT_SUPPORTED; -} -cublasStatus_t cublasGemmStridedBatchedEx_v11(...) { - return CUBLAS_STATUS_NOT_SUPPORTED; -} - -#endif // if CUDA_VERSION >= 11000 - -#endif // #ifndef INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/cupy_blas.h b/src/serinv/cupyfix_backends/cupy_blas.h index ad2cffa5..8e7401a6 100644 --- a/src/serinv/cupyfix_backends/cupy_blas.h +++ b/src/serinv/cupyfix_backends/cupy_blas.h @@ -1,17 +1,7 @@ #ifndef INCLUDE_GUARD_CUPY_CUBLAS_H #define INCLUDE_GUARD_CUPY_CUBLAS_H -#if CUPY_USE_HIP - #include "hip/cupy_hipblas.h" -#elif !defined(CUPY_NO_CUDA) - -#include "cuda/cupy_cublas.h" - -#else // #ifndef CUPY_NO_CUDA - -#include "stub/cupy_cublas.h" -#endif // #ifndef CUPY_NO_CUDA #endif // #ifndef INCLUDE_GUARD_CUPY_CUBLAS_H \ No newline at end of file From 5398b4a59cd1a8683bdf93da728ae0cc53c92a57 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 23:29:46 +0000 Subject: [PATCH 149/157] changed include to cuda --- .../cupyfix_backends/cuda/cupy_cublas.h | 24 +++++++++++++++++++ src/serinv/cupyfix_backends/cupy_blas.h | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 src/serinv/cupyfix_backends/cuda/cupy_cublas.h diff --git a/src/serinv/cupyfix_backends/cuda/cupy_cublas.h b/src/serinv/cupyfix_backends/cuda/cupy_cublas.h new file mode 100644 index 00000000..c3d4874b --- /dev/null +++ b/src/serinv/cupyfix_backends/cuda/cupy_cublas.h @@ -0,0 +1,24 @@ +#ifndef INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H +#define INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H + +#include +#include + +#if CUDA_VERSION >= 11000 + +#define cublasGemmEx_v11 cublasGemmEx +#define cublasGemmStridedBatchedEx_v11 cublasGemmStridedBatchedEx + +#else + +typedef enum{} cublasComputeType_t; +cublasStatus_t cublasGemmEx_v11(...) { + return CUBLAS_STATUS_NOT_SUPPORTED; +} +cublasStatus_t cublasGemmStridedBatchedEx_v11(...) { + return CUBLAS_STATUS_NOT_SUPPORTED; +} + +#endif // if CUDA_VERSION >= 11000 + +#endif // #ifndef INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/cupy_blas.h b/src/serinv/cupyfix_backends/cupy_blas.h index 8e7401a6..0f9006df 100644 --- a/src/serinv/cupyfix_backends/cupy_blas.h +++ b/src/serinv/cupyfix_backends/cupy_blas.h @@ -1,7 +1,7 @@ #ifndef INCLUDE_GUARD_CUPY_CUBLAS_H #define INCLUDE_GUARD_CUPY_CUBLAS_H -#include "hip/cupy_hipblas.h" +#include "cuda/cupy_cublas.h" #endif // #ifndef INCLUDE_GUARD_CUPY_CUBLAS_H \ No newline at end of file From ec1a42c817912aba31e15b9095d87010f725d084 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 06:26:57 +0000 Subject: [PATCH 150/157] include cuda --- setup.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 70ea207c..71a741fd 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,11 @@ from setuptools import setup, Extension from Cython.Build import cythonize +import os + +CONDA_PREFIX = os.environ.get("CONDA_PREFIX", "") +CUDA_INCLUDE = os.path.join(CONDA_PREFIX, "include") + + ext = Extension( name="cupyfix_backends.cuda.libs.cublas", @@ -9,7 +15,8 @@ "cupyfix_backends/hip", "cupyfix_backends/stub", "cupyfix_backends/cuda", - "cupyfix_backends"], + "cupyfix_backends", + CUDA_INCLUDE], ) setup( From ff95437d2e525cb6116f7eba9aabeddb45e3e3d0 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 06:29:55 +0000 Subject: [PATCH 151/157] include cuda lib --- setup.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index 71a741fd..7c47b0db 100644 --- a/setup.py +++ b/setup.py @@ -4,19 +4,23 @@ CONDA_PREFIX = os.environ.get("CONDA_PREFIX", "") CUDA_INCLUDE = os.path.join(CONDA_PREFIX, "include") - +CUDA_LIB = os.path.join(CONDA_PREFIX, "lib") ext = Extension( name="cupyfix_backends.cuda.libs.cublas", sources=[ - "src/serinv/cupyfix_backends/cuda/libs/cublas.pyx"], - include_dirs=["cupyfix_backends/cuda/libs", - "cupyfix_backends/hip", - "cupyfix_backends/stub", - "cupyfix_backends/cuda", - "cupyfix_backends", - CUDA_INCLUDE], + "src/serinv/cupyfix_backends/cuda/libs/cublas.pyx" + ], + include_dirs=["cupyfix_backends/cuda/libs", + "cupyfix_backends/hip", + "cupyfix_backends/stub", + "cupyfix_backends/cuda", + "cupyfix_backends", + CUDA_INCLUDE + ], + library_dirs=[CUDA_LIB], + ) setup( From d563cf41e60156d9e87c11b911df2d229f2dd4cf Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 06:45:22 +0000 Subject: [PATCH 152/157] include hopefully correct path --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 7c47b0db..228d72f0 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,9 @@ import os CONDA_PREFIX = os.environ.get("CONDA_PREFIX", "") -CUDA_INCLUDE = os.path.join(CONDA_PREFIX, "include") -CUDA_LIB = os.path.join(CONDA_PREFIX, "lib") +CONDA_PREFIX = os.path.join() +CUDA_INCLUDE = os.path.join(CONDA_PREFIX, "targets", "x86_64-linux", "include") + ext = Extension( @@ -19,7 +20,6 @@ "cupyfix_backends", CUDA_INCLUDE ], - library_dirs=[CUDA_LIB], ) From 51a85d7cfe3c0cac7512f8f467da688a022600d8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 06:46:05 +0000 Subject: [PATCH 153/157] removed empty os.join --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 228d72f0..931de405 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import os CONDA_PREFIX = os.environ.get("CONDA_PREFIX", "") -CONDA_PREFIX = os.path.join() + CUDA_INCLUDE = os.path.join(CONDA_PREFIX, "targets", "x86_64-linux", "include") From 2d8bce7c0bb1a1c96d5de624494957a84275e3ac Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 06:48:12 +0000 Subject: [PATCH 154/157] changed imported header --- src/serinv/cupyfix_backends/cupy_complex.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/cupyfix_backends/cupy_complex.h b/src/serinv/cupyfix_backends/cupy_complex.h index 90c77337..09589cde 100644 --- a/src/serinv/cupyfix_backends/cupy_complex.h +++ b/src/serinv/cupyfix_backends/cupy_complex.h @@ -3,7 +3,7 @@ -#include "hip/cupy_cuComplex.h" +#include #endif // #ifndef INCLUDE_GUARD_CUPY_COMPLEX_H \ No newline at end of file From cfb1d17e6ad4d734dabade785f619ac22c64c904 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 06:52:23 +0000 Subject: [PATCH 155/157] removed conflicting cdefs --- src/serinv/cupyfix_backends/cuda/libs/cublas.pyx | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx index 704ea943..c2dcf3cb 100644 --- a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -9,18 +9,6 @@ from libc.stdint cimport intptr_t from cupy_backends.cuda.api import runtime from cupy_backends.cuda import stream as stream_module -cdef: - ctypedef void* Handle 'cublasHandle_t' - - ctypedef int DiagType 'cublasDiagType_t' - ctypedef int FillMode 'cublasFillMode_t' - ctypedef int Operation 'cublasOperation_t' - ctypedef int PointerMode 'cublasPointerMode_t' - ctypedef int SideMode 'cublasSideMode_t' - ctypedef int GemmAlgo 'cublasGemmAlgo_t' - ctypedef int Math 'cublasMath_t' - ctypedef int ComputeType 'cublasComputeType_t' - ############################################################################### # Extern ############################################################################### From 2d92e232f8770462ee1cdf18dd63f606ff56c55e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 07:40:08 +0000 Subject: [PATCH 156/157] removed call to attempt of selfmade cuda api --- src/serinv/block_primitive/syherk.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 0ad0fdcb..d39dc2d3 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,7 +1,7 @@ from serinv import _get_module_from_array import serinv.block_primitive.trymod -from serinv.cupyfix_backends.cuda.libs import cublas as cublasfix + import numpy as np from numpy.linalg import matmul @@ -119,9 +119,9 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals elif dtype == 'd': func = cublas.dsyrk elif dtype == 'F': - func = cublasfix.cherk + func = cublas.cherk elif dtype == 'D': - func = cublasfix.zherk + func = cublas.zherk else: raise TypeError('invalid dtype') From 50f91835a09ad8946d0689b703ee642025b0f8d6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 07:44:50 +0000 Subject: [PATCH 157/157] damagge control --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 931de405..90ac7218 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +""" from setuptools import setup, Extension from Cython.Build import cythonize import os @@ -27,4 +28,5 @@ name="cupyfix_backends", ext_modules=cythonize([ext]), packages=["src/serinv/cupyfix_backends.cuda.libs"], -) \ No newline at end of file +) +""" \ No newline at end of file