From 32eea2f1a50430786f6de85a639a63d75ce30511 Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Sat, 5 Apr 2025 09:27:04 +0200 Subject: [PATCH 001/518] added timers for comm --- src/serinv/wrappers/pddbtasc.py | 9 +++++++++ src/serinv/wrappers/pddbtsc.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index da2fcff4..7cb0374c 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -202,12 +204,17 @@ def pddbtasc( _rhs=ddbtars.get("_rhs", None), ) + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_ddbtars( ddbtars=ddbtars, quadratic=quadratic, comm=comm, strategy=strategy, ) + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic ddbtars["A_arrow_tip_block"][:] += A_arrow_tip_initial if quadratic: @@ -226,3 +233,5 @@ def pddbtasc( ) comm.Barrier() + + return elapsed diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index fc0a3765..b846bb07 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. 
+import time + from mpi4py import MPI from serinv import ( @@ -163,12 +165,17 @@ def pddbtsc( _rhs=ddbtrs.get("_rhs", None), ) + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_ddbtrs( ddbtrs=ddbtrs, quadratic=quadratic, comm=comm, strategy=strategy, ) + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Perform Schur complement on the reduced system ddbtsc( @@ -180,3 +187,5 @@ def pddbtsc( ) comm.Barrier() + + return elapsed \ No newline at end of file From 2dc799bc8893a28eb58e5f22020ceecfc114fdaa Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Sat, 5 Apr 2025 18:13:10 +0200 Subject: [PATCH 002/518] added nccl --- src/serinv/__init__.py | 6 +- src/serinv/wrappers/ddbtars.py | 109 +++++++++++++++++-------- src/serinv/wrappers/ddbtrs.py | 75 ++++++++++++----- src/serinv/wrappers/pddbtasc.py | 8 +- src/serinv/wrappers/pddbtasci.py | 3 + src/serinv/wrappers/pddbtsc.py | 6 ++ src/serinv/wrappers/pddbtsci.py | 6 +- src/serinv/wrappers/pobtars.py | 135 ++++++++++++++++++++++++------- src/serinv/wrappers/pobtrs.py | 115 +++++++++++++++++++++----- src/serinv/wrappers/ppobtaf.py | 16 ++++ src/serinv/wrappers/ppobtas.py | 16 ++++ src/serinv/wrappers/ppobtasi.py | 4 +- src/serinv/wrappers/ppobtf.py | 16 ++++ src/serinv/wrappers/ppobts.py | 17 ++++ src/serinv/wrappers/ppobtsi.py | 4 +- 15 files changed, 426 insertions(+), 110 deletions(-) diff --git a/src/serinv/__init__.py b/src/serinv/__init__.py index ca7f21dd..bad9c860 100644 --- a/src/serinv/__init__.py +++ b/src/serinv/__init__.py @@ -163,7 +163,7 @@ def _use_nccl(comm): return False -def _get_nccl_parameters(arr, comm, op: str): +def _get_nccl_parameters(arr, comm, rank, op: str): """Get the NCCL parameters for the given operation.""" if np.iscomplexobj(arr): factor = 2 @@ -172,8 +172,8 @@ def _get_nccl_parameters(arr, comm, op: str): if backend_flags["nccl_avail"]: if op == "allgather": - count = (arr.size // comm.size) * factor - displacement = count * comm.rank * 
arr.dtype.itemsize + count = (arr.size // comm.size()) * factor + displacement = count * rank * (arr.dtype.itemsize // factor) elif op == "allreduce": count = arr.size * factor displacement = 0 diff --git a/src/serinv/wrappers/ddbtars.py b/src/serinv/wrappers/ddbtars.py index 46be5972..722582a6 100644 --- a/src/serinv/wrappers/ddbtars.py +++ b/src/serinv/wrappers/ddbtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_ddbtars( A_diagonal_blocks: ArrayLike, @@ -30,7 +27,14 @@ def allocate_ddbtars( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -80,7 +84,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -144,7 +148,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. 
@@ -224,8 +228,15 @@ def map_ddbtasc_to_ddbtars( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str, + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +371,14 @@ def aggregate_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -455,7 +473,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. if comm_rank == 0: @@ -564,11 +582,12 @@ def aggregate_ddbtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -576,9 +595,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -586,9 +605,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, 
displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -596,9 +615,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -606,9 +625,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_arrow_blocks_comm.data.ptr, count=count, @@ -616,17 +635,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -662,11 +682,12 @@ def aggregate_ddbtars( ddbtars["A_upper_arrow_blocks"] = _A_upper_arrow_blocks[1:] if quadratic: - if _use_nccl(comm): + if 
_use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -674,9 +695,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -684,9 +705,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -694,9 +715,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_arrow_blocks_comm.data.ptr, count=count, @@ -704,9 +725,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, 
op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_arrow_blocks_comm.data.ptr, count=count, @@ -714,17 +735,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_B_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_arrow_tip_block_comm.data.ptr, recvbuf=_B_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -763,7 +785,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -787,8 +809,18 @@ def scatter_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() + _A_diagonal_blocks: ArrayLike = ddbtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = ddbtars.get("A_lower_diagonal_blocks", None) _A_upper_diagonal_blocks: ArrayLike = ddbtars.get("A_upper_diagonal_blocks", None) @@ -857,8 +889,15 @@ def map_ddbtars_to_ddbtasci( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ddbtrs.py 
b/src/serinv/wrappers/ddbtrs.py index 15393ae9..937d8def 100644 --- a/src/serinv/wrappers/ddbtrs.py +++ b/src/serinv/wrappers/ddbtrs.py @@ -24,7 +24,14 @@ def allocate_ddbtrs( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -57,7 +64,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -98,7 +105,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -150,8 +157,15 @@ def map_ddbtsc_to_ddbtrs( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -241,7 +255,14 @@ def aggregate_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -304,7 +325,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. 
if comm_rank == 0: @@ -379,11 +400,12 @@ def aggregate_ddbtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -391,9 +413,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -401,9 +423,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -411,6 +433,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -433,11 +456,12 @@ def aggregate_ddbtrs( ddbtrs["A_upper_diagonal_blocks"] = _A_upper_diagonal_blocks[1:-2] if quadratic: - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) 
- comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -445,9 +469,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -455,9 +479,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -465,6 +489,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -492,7 +517,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -510,8 +535,15 @@ def scatter_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -565,8 +597,15 @@ def map_ddbtrs_to_ddbtsci( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + 
communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index 7cb0374c..6c756c92 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -24,6 +24,7 @@ def pddbtasc( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -117,9 +118,10 @@ def pddbtasc( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) - buffers: dict = kwargs.get("buffers", None) ddbtars: dict = kwargs.get("ddbtars", None) strategy: str = kwargs.get("strategy", "allgather") @@ -202,6 +204,7 @@ def pddbtasc( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) MPI.COMM_WORLD.Barrier() @@ -211,7 +214,10 @@ def pddbtasc( quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() MPI.COMM_WORLD.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/pddbtasci.py b/src/serinv/wrappers/pddbtasci.py index f86b6a9c..f235d76b 100644 --- a/src/serinv/wrappers/pddbtasci.py +++ b/src/serinv/wrappers/pddbtasci.py @@ -22,6 +22,7 @@ def pddbtasci( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. 
@@ -159,6 +160,7 @@ def pddbtasci( ddbtars=ddbtars, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtars_to_ddbtasci( @@ -179,6 +181,7 @@ def pddbtasci( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index b846bb07..c1ec3b72 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -21,6 +21,7 @@ def pddbtsc( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -84,6 +85,7 @@ def pddbtsc( - _B_upper_diagonal_blocks : ArrayLike The upper diagonal blocks of the reduced system. """ + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -163,6 +165,7 @@ def pddbtsc( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) MPI.COMM_WORLD.Barrier() @@ -172,7 +175,10 @@ def pddbtsc( quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() MPI.COMM_WORLD.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/pddbtsci.py b/src/serinv/wrappers/pddbtsci.py index 144054f8..63d94c84 100644 --- a/src/serinv/wrappers/pddbtsci.py +++ b/src/serinv/wrappers/pddbtsci.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import ddbtsci @@ -19,6 +18,7 @@ def pddbtsci( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. 
@@ -89,8 +89,6 @@ def pddbtsci( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") - xp, _ = _get_module_from_array(arr=A_diagonal_blocks) - rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) buffers: dict = kwargs.get("buffers", None) @@ -125,6 +123,7 @@ def pddbtsci( ddbtrs=ddbtrs, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtrs_to_ddbtsci( @@ -139,6 +138,7 @@ def pddbtsci( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pobtars.py b/src/serinv/wrappers/pobtars.py index 98cfdf46..dffa86e8 100644 --- a/src/serinv/wrappers/pobtars.py +++ b/src/serinv/wrappers/pobtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_pobtars( A_diagonal_blocks: ArrayLike, @@ -29,6 +26,7 @@ def allocate_pobtars( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PPOBTARX algorithms. @@ -56,6 +54,12 @@ def allocate_pobtars( pobtars : dict Dictionary containing the reduced system arrays. 
""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -106,7 +110,7 @@ def allocate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -152,6 +156,7 @@ def map_ppobtax_to_pobtars( buffer: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the the boundary blocks of the PPOBTAX algorithm to the reduced system. @@ -178,6 +183,12 @@ def map_ppobtax_to_pobtars( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -228,9 +239,16 @@ def map_ppobtas_to_pobtarss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -252,6 +270,7 @@ def aggregate_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -269,7 +288,14 @@ def aggregate_pobtars( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -306,7 +332,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. _A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -317,11 +343,12 @@ def aggregate_pobtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -329,9 +356,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -339,9 +366,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( 
sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -349,17 +376,18 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -374,7 +402,7 @@ def aggregate_pobtars( ) comm.Allreduce(MPI.IN_PLACE, _A_arrow_tip_block_comm, op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -432,7 +460,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -449,8 +477,15 @@ def aggregate_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -474,7 +509,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. 
@@ -484,10 +519,11 @@ def aggregate_pobtarss( if strategy == "allgather": if _use_nccl(comm): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[:-a], comm=comm, op="allgather" + arr=_B_comm[:-a], comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm[:-a].data.ptr + displacement, recvbuf=_B_comm[:-a].data.ptr, count=count, @@ -495,24 +531,25 @@ def aggregate_pobtarss( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[-a:], comm=comm, op="allreduce" + arr=_B_comm[-a:], comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_comm[-a:].data.ptr, recvbuf=_B_comm[-a:].data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm[:-a], ) comm.Allreduce(MPI.IN_PLACE, _B_comm[-a:], op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -547,7 +584,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -559,10 +596,18 @@ def scatter_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -598,6 +643,11 @@ def scatter_pobtars( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -607,7 +657,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -652,7 +702,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -671,9 +721,17 @@ def scatter_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() b = A_diagonal_blocks[0].shape[0] a = A_arrow_tip_block.shape[0] @@ -695,6 +753,11 @@ def 
scatter_pobtarss( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -703,7 +766,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -727,7 +790,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -747,6 +810,7 @@ def map_pobtars_to_ppobtax( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -772,6 +836,12 @@ def map_pobtars_to_ppobtax( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -824,9 +894,16 @@ def map_pobtarss_to_ppobtas( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pobtrs.py b/src/serinv/wrappers/pobtrs.py index c716317a..e3f48a2e 100644 --- a/src/serinv/wrappers/pobtrs.py +++ b/src/serinv/wrappers/pobtrs.py @@ -24,6 +24,7 @@ def allocate_pobtrs( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PpobtRX algorithms. @@ -47,6 +48,12 @@ def allocate_pobtrs( pobtrs : dict Dictionary containing the reduced system arrays. """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -92,7 +99,7 @@ def allocate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -126,6 +133,7 @@ def map_ppobtx_to_pobtrs( comm: MPI.Comm, buffer: ArrayLike, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: """Map the the boundary blocks of the PpobtX algorithm to the reduced system. @@ -144,6 +152,12 @@ def map_ppobtx_to_pobtrs( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -183,9 +197,16 @@ def map_ppobts_to_pobtrss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -204,6 +225,7 @@ def aggregate_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -215,6 +237,12 @@ def aggregate_pobtrs( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -241,7 +269,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. 
_A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -250,11 +278,12 @@ def aggregate_pobtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -262,9 +291,9 @@ def aggregate_pobtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -272,6 +301,7 @@ def aggregate_pobtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -281,7 +311,7 @@ def aggregate_pobtrs( _A_lower_diagonal_blocks_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -322,7 +352,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -336,8 +366,15 @@ def aggregate_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +397,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. @@ -369,11 +406,12 @@ def aggregate_pobtrss( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm, comm=comm, op="allgather" + arr=_B_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm.data.ptr + displacement, recvbuf=_B_comm.data.ptr, count=count, @@ -381,12 +419,13 @@ def aggregate_pobtrss( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -415,7 +454,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -427,9 +466,16 @@ def scatter_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -456,6 +502,11 @@ def scatter_pobtrs( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -465,7 +516,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -496,7 +547,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -513,8 +564,15 @@ def scatter_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -536,6 +594,11 @@ def scatter_pobtrss( if strategy == "allgather": ... 
elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -544,7 +607,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -564,7 +627,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -580,6 +643,7 @@ def map_pobtrs_to_ppobtx( _A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -597,6 +661,12 @@ def map_pobtrs_to_ppobtx( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -638,9 +708,16 @@ def map_pobtrss_to_ppobts( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ppobtaf.py b/src/serinv/wrappers/ppobtaf.py index 64946a3d..240395e0 100644 --- a/src/serinv/wrappers/ppobtaf.py +++ b/src/serinv/wrappers/ppobtaf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. 
+import time + from mpi4py import MPI from serinv import ( @@ -20,6 +22,7 @@ def ppobtaf( A_lower_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -62,6 +65,8 @@ def ppobtaf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -133,14 +138,23 @@ def ppobtaf( buffer=buffer, strategy=strategy, comm=comm, + nccl_comm=nccl_comm, ) + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_pobtars( pobtars=pobtars, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- pobtars["A_arrow_tip_block"][:] += A_arrow_tip_initial @@ -165,3 +179,5 @@ def ppobtaf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtas.py b/src/serinv/wrappers/ppobtas.py index 132a9d08..2feda2a2 100644 --- a/src/serinv/wrappers/ppobtas.py +++ b/src/serinv/wrappers/ppobtas.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -23,6 +25,7 @@ def ppobtas( L_arrow_tip_block: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -138,16 +141,25 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Agregate reduced RHS + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_pobtarss( A_diagonal_blocks=L_diagonal_blocks, A_arrow_tip_block=L_arrow_tip_block, pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Add the tip block of the RHS to the aggregated update _B[-a:] += B_tip_initial @@ -180,6 +192,7 @@ def ppobtas( pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -190,6 +203,7 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -213,3 +227,5 @@ def ppobtas( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtasi.py b/src/serinv/wrappers/ppobtasi.py index 12e23e14..96b67c20 100644 --- a/src/serinv/wrappers/ppobtasi.py +++ b/src/serinv/wrappers/ppobtasi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtasi @@ -20,6 +19,7 @@ def ppobtasi( L_lower_arrow_blocks: ArrayLike, L_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -114,6 +114,7 @@ def ppobtasi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -129,6 +130,7 @@ def ppobtasi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system diff --git a/src/serinv/wrappers/ppobtf.py b/src/serinv/wrappers/ppobtf.py index 2d680c96..b8f8e839 100644 --- a/src/serinv/wrappers/ppobtf.py +++ b/src/serinv/wrappers/ppobtf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -18,6 +20,7 @@ def ppobtf( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -56,6 +59,8 @@ def ppobtf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -106,14 +111,23 @@ def ppobtf( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_pobtrs( pobtrs=pobtrs, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- if strategy == "gather-scatter": @@ -134,3 +148,5 @@ def ppobtf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobts.py b/src/serinv/wrappers/ppobts.py index 396577f1..bef5b48e 100644 --- a/src/serinv/wrappers/ppobts.py +++ b/src/serinv/wrappers/ppobts.py @@ -1,5 +1,7 @@ # Copyright 
2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -21,6 +23,7 @@ def ppobts( L_lower_diagonal_blocks: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -119,14 +122,24 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + # Agregate reduced RHS + MPI.COMM_WORLD.Barrier() + tic = time.perf_counter() aggregate_pobtrss( A_diagonal_blocks=L_diagonal_blocks, pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + MPI.COMM_WORLD.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Solve RHS FWD/BWD if strategy == "allgather": @@ -151,6 +164,7 @@ def ppobts( pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -160,6 +174,7 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -179,3 +194,5 @@ def ppobts( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtsi.py b/src/serinv/wrappers/ppobtsi.py index fde320fc..72c7354d 100644 --- a/src/serinv/wrappers/ppobtsi.py +++ b/src/serinv/wrappers/ppobtsi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtsi @@ -18,6 +17,7 @@ def ppobtsi( L_diagonal_blocks: ArrayLike, L_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -100,6 +100,7 @@ def ppobtsi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -111,6 +112,7 @@ def ppobtsi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system From 2d66457349d4dafc3ec4aa54153898afec2a086d Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Sat, 5 Apr 2025 22:41:34 +0200 Subject: [PATCH 003/518] removed explicit comm world call --- src/serinv/wrappers/pddbtasc.py | 4 ++-- src/serinv/wrappers/pddbtsc.py | 4 ++-- src/serinv/wrappers/ppobtaf.py | 4 ++-- src/serinv/wrappers/ppobtas.py | 4 ++-- src/serinv/wrappers/ppobtf.py | 4 ++-- src/serinv/wrappers/ppobts.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index 6c756c92..27335195 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -207,7 +207,7 @@ def pddbtasc( nccl_comm=nccl_comm, ) - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_ddbtars( ddbtars=ddbtars, @@ -218,7 +218,7 @@ def pddbtasc( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index c1ec3b72..e9f1eb9e 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -168,7 +168,7 @@ def pddbtsc( nccl_comm=nccl_comm, ) - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_ddbtrs( ddbtrs=ddbtrs, @@ -179,7 +179,7 @@ def pddbtsc( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/ppobtaf.py b/src/serinv/wrappers/ppobtaf.py index 240395e0..2122b88a 100644 --- 
a/src/serinv/wrappers/ppobtaf.py +++ b/src/serinv/wrappers/ppobtaf.py @@ -141,7 +141,7 @@ def ppobtaf( nccl_comm=nccl_comm, ) - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_pobtars( pobtars=pobtars, @@ -152,7 +152,7 @@ def ppobtaf( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/ppobtas.py b/src/serinv/wrappers/ppobtas.py index 2feda2a2..52896530 100644 --- a/src/serinv/wrappers/ppobtas.py +++ b/src/serinv/wrappers/ppobtas.py @@ -145,7 +145,7 @@ def ppobtas( ) # Agregate reduced RHS - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_pobtarss( A_diagonal_blocks=L_diagonal_blocks, @@ -157,7 +157,7 @@ def ppobtas( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/ppobtf.py b/src/serinv/wrappers/ppobtf.py index b8f8e839..f1b44956 100644 --- a/src/serinv/wrappers/ppobtf.py +++ b/src/serinv/wrappers/ppobtf.py @@ -114,7 +114,7 @@ def ppobtf( nccl_comm=nccl_comm, ) - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_pobtrs( pobtrs=pobtrs, @@ -125,7 +125,7 @@ def ppobtf( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() toc = time.perf_counter() elapsed = toc - tic diff --git a/src/serinv/wrappers/ppobts.py b/src/serinv/wrappers/ppobts.py index bef5b48e..86906883 100644 --- a/src/serinv/wrappers/ppobts.py +++ b/src/serinv/wrappers/ppobts.py @@ -126,7 +126,7 @@ def ppobts( ) # Agregate reduced RHS - MPI.COMM_WORLD.Barrier() + comm.Barrier() tic = time.perf_counter() aggregate_pobtrss( A_diagonal_blocks=L_diagonal_blocks, @@ -137,7 +137,7 @@ def ppobts( ) if xp.__name__ == "cupy": xp.cuda.runtime.deviceSynchronize() - MPI.COMM_WORLD.Barrier() + comm.Barrier() 
toc = time.perf_counter() elapsed = toc - tic From 791daa7bc32eca05c368d013cfaf91f9e5051e66 Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Sat, 5 Apr 2025 23:58:04 +0200 Subject: [PATCH 004/518] added nccl synch --- src/serinv/wrappers/pobtars.py | 3 +++ src/serinv/wrappers/pobtrs.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/serinv/wrappers/pobtars.py b/src/serinv/wrappers/pobtars.py index dffa86e8..ee6f26bb 100644 --- a/src/serinv/wrappers/pobtars.py +++ b/src/serinv/wrappers/pobtars.py @@ -345,6 +345,7 @@ def aggregate_pobtars( if strategy == "allgather": if _use_nccl(communicator): # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) @@ -386,6 +387,8 @@ def aggregate_pobtars( op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: # --- Use MPI --- comm.Allgather( diff --git a/src/serinv/wrappers/pobtrs.py b/src/serinv/wrappers/pobtrs.py index e3f48a2e..391953d4 100644 --- a/src/serinv/wrappers/pobtrs.py +++ b/src/serinv/wrappers/pobtrs.py @@ -280,6 +280,7 @@ def aggregate_pobtrs( if strategy == "allgather": if _use_nccl(communicator): # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) @@ -300,6 +301,8 @@ def aggregate_pobtrs( datatype=datatype, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: # --- Use MPI --- comm.Allgather( From 30cf672070bacb48f520044f42e63773755b81ef Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 13:33:57 +0000 Subject: [PATCH 005/518] first version of pobtas streaming --- src/serinv/algs/pobtas.py | 214 +++++++++++++++++- .../regular/tests_bta/test_pobtas.py | 27 ++- 2 files changed, 238 insertions(+), 3 deletions(-) 
diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bab2a911..e0c87b82 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -4,6 +4,7 @@ from serinv import ( ArrayLike, _get_module_from_array, + _get_module_from_str, ) @@ -47,8 +48,14 @@ def pobtas( else: # Natural arrowhead if device_streaming: - raise NotImplementedError( - "Streaming is not implemented for the natural arrowhead." + _pobtas_streaming( + L_diagonal_blocks, + L_lower_diagonal_blocks, + L_lower_arrow_blocks, + L_arrow_tip_block, + B, + trans, + partial, ) else: _pobtas( @@ -216,3 +223,206 @@ def _pobtas_permuted( ) else: raise ValueError(f"Invalid transpose argument: {trans}.") + +def _pobtas_streaming( + L_diagonal_blocks: ArrayLike, + L_lower_diagonal_blocks: ArrayLike, + L_lower_arrow_blocks: ArrayLike, + L_arrow_tip_block: ArrayLike, + B: ArrayLike, + trans: str, + partial: bool, +): + arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks) + if arr_module.__name__ != "numpy": + raise NotImplementedError( + "Host<->Device streaming only works when host-arrays are given." 
+ ) + + cp, cu_la = _get_module_from_str(module_str="cupy") + + # Streams and events + compute_stream = cp.cuda.Stream(non_blocking=True) + h2d_stream = cp.cuda.Stream(non_blocking=True) + d2h_stream = cp.cuda.Stream(non_blocking=True) + + h2d_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_lower_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_B_events = [cp.cuda.Event(), cp.cuda.Event()] + + d2h_B_events = [cp.cuda.Event(), cp.cuda.Event()] + + compute_current_B_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_next_B_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_arrow_B_events = [cp.cuda.Event(), cp.cuda.Event()] + + compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] + + #compute_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] + #compute_arrow_h2d_events = [cp.cuda.Event(), cp.cuda.Event()] + #compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] + #compute_B_h2d_events = [cp.cuda.Event(), cp.cuda.Event()] + + # Vars + diag_blocksize = L_diagonal_blocks.shape[1] + arrow_blocksize = L_lower_arrow_blocks.shape[1] + n_diag_blocks = L_diagonal_blocks.shape[0] + + # Device Buffers + # B Buffers + B_shape = B[0 : diag_blocksize] # block template + B_d = cp.empty( + (2, *B_shape.shape[1:]), dtype=B_shape.dtype + ) + B_shape = B[-arrow_blocksize:] + B_last_block_d = cp.empty_like(B_shape) + del B_shape + + # L Buffers + L_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_arrow_blocks_d = cp.empty( + (2, *L_lower_arrow_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) + + # Forward Pass + # --- C: events + transfers --- + compute_current_B_events[1].record(stream=compute_stream) + compute_next_B_events[1].record(stream=compute_stream) + 
compute_arrow_B_events[1].record(stream=compute_stream) + + B_last_block_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) + L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:, :], stream=h2d_stream) + + # --- H2D: transfers --- + B_d[0].set(arr=B[0 : 1 * diag_blocksize], stream = h2d_stream) + h2d_B_events[0].record(stream=h2d_stream) + + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) + h2d_diagonal_events[0].record(stream=h2d_stream) + + L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + h2d_lower_diagonal_events[0].record(stream=h2d_stream) + + L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[0], stream=h2d_stream) + h2d_arrow_events[0].record(stream=h2d_stream) + + # --- D2H: event --- + d2h_B_events[1].record(stream=d2h_stream) + + n_diag_blocks: int = L_diagonal_blocks.shape[0] # why? + if n_diag_blocks > 1: + L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + h2d_lower_diagonal_events[0].record(stream=h2d_stream) + + + + if trans == "N": + for i in range(0, n_diag_blocks-1): + # --- Forward substitution --- + with compute_stream: + # Compute step 1 : compute B + compute_stream.wait_event(h2d_diagonal_events[i % 2]) + compute_stream.wait_event(compute_arrow_B_events[i % 2]) + compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) + B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] = cu_la.solve_triangular( + L_diagonal_blocks[i % 2], + B[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize], + lower=True, + ) + compute_current_B_events[i % 2].record(stream=compute_stream) + + h2d_stream.wait_event(compute_current_B_events[i % 2]) + L_diagonal_blocks_d[(i + 2) % 2].set(arr=L_diagonal_blocks[i + 2], stream=h2d_stream) + h2d_diagonal_events[i % 2].record(stream=h2d_stream) + + d2h_stream.wait_event(compute_next_B_events[i % 2]) + B_d[i % 2].get( + out=B[i * diag_blocksize : (i + 1) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + 
d2h_B_events[i % 2].record(stream=d2h_stream) + + with compute_stream: + # 2 + compute_stream.wait_event(h2d_lower_diagonal_events[i % 2]) + compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) + compute_stream.wait_event(compute_current_B_events[i % 2]) + compute_stream.wait_event(compute_next_B_events[(i + 1) % 2]) + B_d[(i + 1) % 2 * diag_blocksize : (i + 2) % 2 * diag_blocksize] -= ( + L_lower_diagonal_blocks[i%2] + @ B[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] + ) + compute_next_B_events[i % 2].record(stream=compute_stream) + + h2d_stream.wait_event(compute_next_B_events[i % 2]) + L_lower_diagonal_blocks_d[(i + 2) % 2].set(arr=L_lower_diagonal_blocks[i + 2], stream=h2d_stream) + h2d_lower_diagonal_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + # 3 + compute_stream.wait_event(h2d_arrow_events[i % 2]) + compute_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + compute_stream.wait_event(compute_next_B_events[i % 2]) + B_last_block_d -= ( + L_lower_arrow_blocks_d[i % 2] + @ B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] + ) + compute_arrow_B_events[i % 2].record(stream=compute_stream) + + h2d_stream.wait_event(compute_arrow_B_events[i % 2]) + B_d[(i + 2) % 2].set(arr=B[(i + 2) * diag_blocksize : (i + 3) * diag_blocksize], stream = h2d_stream) + h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + + L_lower_arrow_blocks_d[(i + 1) % 2].set(arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream) + h2d_arrow_events[i % 2].record(stream=h2d_stream) + + + if not partial: + # In the case of the partial solve, we do not solve the last block and + # arrow tip block of the RHS. 
+ + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) + h2d_diagonal_events[0].record(stream=h2d_stream) + + L_lower_arrow_blocks_d[0].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) + h2d_arrow_events[0].record(stream=h2d_stream) + + + with compute_stream: + + compute_stream.wait_event(h2d_diagonal_events[0]) + B_last_block_d = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[0], lower=True,)) + compute_partial_events[0].record(stream=compute_stream) + + compute_stream.wait_event(h2d_arrow_events[0]) + compute_stream.wait_event(compute_partial_events[0]) + B_last_block_d -= (L_lower_arrow_blocks_d[-1] @ B_last_block_d[1]) + compute_partial_events[1].record(stream=compute_stream) + + d2h_stream.wait_event(compute_partial_events[1]) + B_d[i % 2].get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + + # Y_{ndb+1} = L_{ndb+1,ndb+1}^{-1} (B_{ndb+1} - \Sigma_{i=1}^{ndb} L_{ndb+1,i} Y_{i) + + elif trans == "T" or trans == "C": + # ----- Backward substitution ----- + if not partial: + # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) + raise NotImplementedError( + "T and C not yet implemented." 
+ ) + # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) + + # for i in range(n_diag_blocks -2, -1, -1): + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + + else: + raise ValueError(f"Invalid transpose argument: {trans}.") \ No newline at end of file diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 647a0168..b8810c77 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -3,11 +3,14 @@ import numpy as np import pytest -from serinv import _get_module_from_array +from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize, rhs from serinv.algs import pobtaf, pobtas +if backend_flags["cupy_avail"]: + import cupyx as cpx + @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) @@ -19,6 +22,9 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): + + array_type = "streaming" + A = dd_bta( diagonal_blocksize, arrowhead_blocksize, @@ -51,6 +57,24 @@ def test_pobtas( A_arrow_tip_block, ) = bta_dense_to_arrays(A, diagonal_blocksize, arrowhead_blocksize, n_diag_blocks) + if backend_flags["cupy_avail"] and array_type == "streaming": + A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks) + A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :] + A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks) + A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks[:, :, :] + A_lower_arrow_blocks_pinned = cpx.zeros_like_pinned(A_lower_arrow_blocks) + A_lower_arrow_blocks_pinned[:, :, :] = A_lower_arrow_blocks[:, :, :] + A_arrow_tip_block_pinned = cpx.zeros_like_pinned(A_arrow_tip_block) + A_arrow_tip_block_pinned[:, :] = A_arrow_tip_block[:, :] + B_pinned = cpx.zeros_like_pinned(B) + B_pinned[:, :] = B[:, :] + + A_diagonal_blocks = A_diagonal_blocks_pinned + 
A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned + A_lower_arrow_blocks = A_lower_arrow_blocks_pinned + A_arrow_tip_block = A_arrow_tip_block_pinned + B = B_pinned + pobtaf( A_diagonal_blocks, A_lower_diagonal_blocks, @@ -66,6 +90,7 @@ def test_pobtas( A_arrow_tip_block, B, trans="N", + device_streaming=True if array_type == "streaming" else False, ) # Backward solve: X=L^{-T}Y From 9e97c5e5030d65c26a8c27145c1fab0a1658307e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 13:36:35 +0000 Subject: [PATCH 006/518] change tests incase that broke it --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index b8810c77..94f73c2a 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -23,7 +23,6 @@ def test_pobtas( dtype: np.dtype, ): - array_type = "streaming" A = dd_bta( diagonal_blocksize, From 380a08492aa3cc1777cb74ef90020e773ae3d85f Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 13:38:32 +0000 Subject: [PATCH 007/518] test update --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 94f73c2a..ce9ac254 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -22,7 +22,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): - + array_type == "streaming" A = dd_bta( diagonal_blocksize, From 320ebef0661e40c7ac5df949c62e6b7b6c950981 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 13:43:10 +0000 Subject: [PATCH 008/518] typo --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index ce9ac254..cbb7aaed 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -22,7 +22,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): - array_type == "streaming" + array_type = "streaming" A = dd_bta( diagonal_blocksize, From 6eeb0f9c58cfc635cbd86b317eccd155b2d9b463 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:23:36 +0000 Subject: [PATCH 009/518] debug statements --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index e0c87b82..b3b4ee4a 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -272,9 +272,11 @@ def _pobtas_streaming( # Device Buffers # B Buffers B_shape = B[0 : diag_blocksize] # block template + print(B_shape) B_d = cp.empty( (2, *B_shape.shape[1:]), dtype=B_shape.dtype ) + print(B_d) B_shape = B[-arrow_blocksize:] B_last_block_d = cp.empty_like(B_shape) del B_shape From 137112fb0d8e5acafa3b5e595122ab498938399f Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:24:55 +0000 Subject: [PATCH 010/518] debug changes --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index b3b4ee4a..e50e05f9 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -272,10 +272,12 @@ def _pobtas_streaming( # Device Buffers # B Buffers B_shape = B[0 : diag_blocksize] # block template + print("B_shape") print(B_shape) B_d = cp.empty( (2, *B_shape.shape[1:]), dtype=B_shape.dtype ) + print("B_d") print(B_d) B_shape = B[-arrow_blocksize:] B_last_block_d = cp.empty_like(B_shape) From dfab4ab20897451211e050eddadff65599bda3ba Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:27:13 +0000 Subject: [PATCH 011/518] debug messages --- 
src/serinv/algs/pobtas.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index e50e05f9..8b45d3d9 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -305,6 +305,10 @@ def _pobtas_streaming( L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:, :], stream=h2d_stream) # --- H2D: transfers --- + print("B block") + print(B[0 : 1 * diag_blocksize]) + print("B_d 0") + print(B_d[0]) B_d[0].set(arr=B[0 : 1 * diag_blocksize], stream = h2d_stream) h2d_B_events[0].record(stream=h2d_stream) From e068b847d0e7c191a83be0f154d99725fec29745 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:30:42 +0000 Subject: [PATCH 012/518] print B --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 8b45d3d9..c4ac087e 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -307,6 +307,8 @@ def _pobtas_streaming( # --- H2D: transfers --- print("B block") print(B[0 : 1 * diag_blocksize]) + print("B") + print(B) print("B_d 0") print(B_d[0]) B_d[0].set(arr=B[0 : 1 * diag_blocksize], stream = h2d_stream) From 2981b3edca9105708d47d44f0053e32e92c50a1c Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:31:52 +0000 Subject: [PATCH 013/518] changed B_d shape --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index c4ac087e..0c1c0d64 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -275,7 +275,7 @@ def _pobtas_streaming( print("B_shape") print(B_shape) B_d = cp.empty( - (2, *B_shape.shape[1:]), dtype=B_shape.dtype + (2, *B_shape.shape), dtype=B_shape.dtype ) print("B_d") print(B_d) From 29c674fe4972716cf9c34906010cdb4efb5e73b6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:38:55 +0000 Subject: [PATCH 014/518] changed wrong arrays in streaming --- 
src/serinv/algs/pobtas.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 0c1c0d64..3035e889 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -259,11 +259,6 @@ def _pobtas_streaming( compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] - #compute_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] - #compute_arrow_h2d_events = [cp.cuda.Event(), cp.cuda.Event()] - #compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] - #compute_B_h2d_events = [cp.cuda.Event(), cp.cuda.Event()] - # Vars diag_blocksize = L_diagonal_blocks.shape[1] arrow_blocksize = L_lower_arrow_blocks.shape[1] @@ -272,13 +267,9 @@ def _pobtas_streaming( # Device Buffers # B Buffers B_shape = B[0 : diag_blocksize] # block template - print("B_shape") - print(B_shape) B_d = cp.empty( (2, *B_shape.shape), dtype=B_shape.dtype ) - print("B_d") - print(B_d) B_shape = B[-arrow_blocksize:] B_last_block_d = cp.empty_like(B_shape) del B_shape @@ -305,12 +296,6 @@ def _pobtas_streaming( L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:, :], stream=h2d_stream) # --- H2D: transfers --- - print("B block") - print(B[0 : 1 * diag_blocksize]) - print("B") - print(B) - print("B_d 0") - print(B_d[0]) B_d[0].set(arr=B[0 : 1 * diag_blocksize], stream = h2d_stream) h2d_B_events[0].record(stream=h2d_stream) @@ -331,7 +316,6 @@ def _pobtas_streaming( L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_lower_diagonal_events[0].record(stream=h2d_stream) - if trans == "N": for i in range(0, n_diag_blocks-1): @@ -342,8 +326,8 @@ def _pobtas_streaming( compute_stream.wait_event(compute_arrow_B_events[i % 2]) compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] = cu_la.solve_triangular( - L_diagonal_blocks[i % 2], - B[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize], + L_diagonal_blocks_d[i % 
2], + B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize], lower=True, ) compute_current_B_events[i % 2].record(stream=compute_stream) From 3bffe7d8ab696fa83c7a24be492868e8fa1944fc Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:41:45 +0000 Subject: [PATCH 015/518] debug shapes --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 3035e889..33e4aea3 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -325,6 +325,8 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_diagonal_events[i % 2]) compute_stream.wait_event(compute_arrow_B_events[i % 2]) compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) + print(B_d.shape()) + print(L_diagonal_blocks_d.shape()) B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize], From 1e8acadb0d74e1c619ca05404ea7b59775374443 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:43:22 +0000 Subject: [PATCH 016/518] typo --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 33e4aea3..07e01c89 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -325,8 +325,8 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_diagonal_events[i % 2]) compute_stream.wait_event(compute_arrow_B_events[i % 2]) compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) - print(B_d.shape()) - print(L_diagonal_blocks_d.shape()) + print(B_d.shape) + print(L_diagonal_blocks_d.shape) B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize], From ec0a01b0fc037f4eab2f751808ce5b0ab2abc71e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 
14:45:30 +0000 Subject: [PATCH 017/518] compare B and L --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 07e01c89..a2dca781 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -325,8 +325,8 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_diagonal_events[i % 2]) compute_stream.wait_event(compute_arrow_B_events[i % 2]) compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) - print(B_d.shape) - print(L_diagonal_blocks_d.shape) + print(B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize].shape) + print(L_diagonal_blocks_d[i % 2].shape) B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize], From 6a657cc6fac80f3ec1e1968a3741a8ff003eabb9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:51:36 +0000 Subject: [PATCH 018/518] changed B slice in 1 --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index a2dca781..7562555c 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -325,7 +325,7 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_diagonal_events[i % 2]) compute_stream.wait_event(compute_arrow_B_events[i % 2]) compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) - print(B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize].shape) + print(B_d[i % 2].shape) print(L_diagonal_blocks_d[i % 2].shape) B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], From c3bb244ce94fb56ddcc49603ba9614ce93327652 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:52:35 +0000 Subject: [PATCH 019/518] changed actual B slices --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 
2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 7562555c..008d99fc 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -327,9 +327,9 @@ def _pobtas_streaming( compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) print(B_d[i % 2].shape) print(L_diagonal_blocks_d[i % 2].shape) - B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] = cu_la.solve_triangular( + B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], - B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize], + B_d[i % 2], lower=True, ) compute_current_B_events[i % 2].record(stream=compute_stream) From 11a9cc63b7af57d08f333910a0f7bc9ccce1b952 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 14:54:31 +0000 Subject: [PATCH 020/518] changed further B slice --- src/serinv/algs/pobtas.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 008d99fc..6c1a259f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -325,8 +325,6 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_diagonal_events[i % 2]) compute_stream.wait_event(compute_arrow_B_events[i % 2]) compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) - print(B_d[i % 2].shape) - print(L_diagonal_blocks_d[i % 2].shape) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2], @@ -352,9 +350,9 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) compute_stream.wait_event(compute_current_B_events[i % 2]) compute_stream.wait_event(compute_next_B_events[(i + 1) % 2]) - B_d[(i + 1) % 2 * diag_blocksize : (i + 2) % 2 * diag_blocksize] -= ( - L_lower_diagonal_blocks[i%2] - @ B[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] + B_d[(i + 1) % 2] -= ( + L_lower_diagonal_blocks[i % 2] + @ B[i % 2] ) compute_next_B_events[i % 2].record(stream=compute_stream) From 
935848de17603d7bb4f91b2e00093db143157cd3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:00:16 +0000 Subject: [PATCH 021/518] fixed typos --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 6c1a259f..4647eeda 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -351,8 +351,8 @@ def _pobtas_streaming( compute_stream.wait_event(compute_current_B_events[i % 2]) compute_stream.wait_event(compute_next_B_events[(i + 1) % 2]) B_d[(i + 1) % 2] -= ( - L_lower_diagonal_blocks[i % 2] - @ B[i % 2] + L_lower_diagonal_blocks_d[i % 2] + @ B_d[i % 2] ) compute_next_B_events[i % 2].record(stream=compute_stream) From f527a69a5bfc1b9041c8db5b7cb79f96119ff2c5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:01:36 +0000 Subject: [PATCH 022/518] changed last B slice --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 4647eeda..45271abd 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -367,7 +367,7 @@ def _pobtas_streaming( compute_stream.wait_event(compute_next_B_events[i % 2]) B_last_block_d -= ( L_lower_arrow_blocks_d[i % 2] - @ B_d[i % 2 * diag_blocksize : (i + 1) % 2 * diag_blocksize] + @ B_d[i % 2] ) compute_arrow_B_events[i % 2].record(stream=compute_stream) From 69fc9a064f66f81f859e421e2bd559532f4942de Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:17:24 +0000 Subject: [PATCH 023/518] changed index for lower diag blocks --- src/serinv/algs/pobtas.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 45271abd..1420c97e 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -302,8 +302,8 @@ def _pobtas_streaming( L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], 
stream=h2d_stream) h2d_diagonal_events[0].record(stream=h2d_stream) - L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) - h2d_lower_diagonal_events[0].record(stream=h2d_stream) + #L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + #h2d_lower_diagonal_events[0].record(stream=h2d_stream) L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[0], stream=h2d_stream) h2d_arrow_events[0].record(stream=h2d_stream) @@ -318,7 +318,7 @@ def _pobtas_streaming( if trans == "N": - for i in range(0, n_diag_blocks-1): + for i in range(0, n_diag_blocks - 1): # --- Forward substitution --- with compute_stream: # Compute step 1 : compute B @@ -357,7 +357,7 @@ def _pobtas_streaming( compute_next_B_events[i % 2].record(stream=compute_stream) h2d_stream.wait_event(compute_next_B_events[i % 2]) - L_lower_diagonal_blocks_d[(i + 2) % 2].set(arr=L_lower_diagonal_blocks[i + 2], stream=h2d_stream) + L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) h2d_lower_diagonal_events[i % 2].record(stream=h2d_stream) with compute_stream: From 467ce64abfd2bd81cc79785e21860757ade295f2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:18:54 +0000 Subject: [PATCH 024/518] changed index for diagonal blocks --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 1420c97e..7284ee28 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -333,7 +333,7 @@ def _pobtas_streaming( compute_current_B_events[i % 2].record(stream=compute_stream) h2d_stream.wait_event(compute_current_B_events[i % 2]) - L_diagonal_blocks_d[(i + 2) % 2].set(arr=L_diagonal_blocks[i + 2], stream=h2d_stream) + L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) h2d_diagonal_events[i % 2].record(stream=h2d_stream) d2h_stream.wait_event(compute_next_B_events[i % 
2]) From aa5c893c3fde04cc120eb55dcca23c761746a32e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:32:28 +0000 Subject: [PATCH 025/518] inserted ifs for termination --- src/serinv/algs/pobtas.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 7284ee28..27e1412b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -313,9 +313,18 @@ def _pobtas_streaming( n_diag_blocks: int = L_diagonal_blocks.shape[0] # why? if n_diag_blocks > 1: + B_d[1].set(arr=B[1 * diag_blocksize : 2 * diag_blocksize], stream = h2d_stream) + h2d_B_events[1].record(stream=h2d_stream) + L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_lower_diagonal_events[0].record(stream=h2d_stream) + L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) + h2d_diagonal_events[1].record(stream=h2d_stream) + + L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) + h2d_lower_diagonal_events[1].record(stream=h2d_stream) + if trans == "N": for i in range(0, n_diag_blocks - 1): @@ -332,9 +341,10 @@ def _pobtas_streaming( ) compute_current_B_events[i % 2].record(stream=compute_stream) - h2d_stream.wait_event(compute_current_B_events[i % 2]) - L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) - h2d_diagonal_events[i % 2].record(stream=h2d_stream) + if i + 2 < n_diag_blocks - 1: + h2d_stream.wait_event(compute_current_B_events[i % 2]) + L_diagonal_blocks_d[(i + 2) % 2].set(arr=L_diagonal_blocks[i + 2], stream=h2d_stream) + h2d_diagonal_events[(i + 2) % 2].record(stream=h2d_stream) d2h_stream.wait_event(compute_next_B_events[i % 2]) B_d[i % 2].get( @@ -356,9 +366,10 @@ def _pobtas_streaming( ) compute_next_B_events[i % 2].record(stream=compute_stream) - h2d_stream.wait_event(compute_next_B_events[i % 2]) - L_lower_diagonal_blocks_d[(i + 1) % 
2].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) - h2d_lower_diagonal_events[i % 2].record(stream=h2d_stream) + if i + 2 < n_diag_blocks - 1: + h2d_stream.wait_event(compute_next_B_events[i % 2]) + L_lower_diagonal_blocks_d[(i + 2) % 2].set(arr=L_lower_diagonal_blocks[i + 2], stream=h2d_stream) + h2d_lower_diagonal_events[(i + 2) % 2].record(stream=h2d_stream) with compute_stream: # 3 @@ -371,12 +382,13 @@ def _pobtas_streaming( ) compute_arrow_B_events[i % 2].record(stream=compute_stream) - h2d_stream.wait_event(compute_arrow_B_events[i % 2]) - B_d[(i + 2) % 2].set(arr=B[(i + 2) * diag_blocksize : (i + 3) * diag_blocksize], stream = h2d_stream) - h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + if i + 2 < n_diag_blocks - 1: + h2d_stream.wait_event(compute_arrow_B_events[i % 2]) + B_d[(i + 2) % 2].set(arr=B[(i + 2) * diag_blocksize : (i + 3) * diag_blocksize], stream = h2d_stream) + h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) - L_lower_arrow_blocks_d[(i + 1) % 2].set(arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream) - h2d_arrow_events[i % 2].record(stream=h2d_stream) + L_lower_arrow_blocks_d[(i + 1) % 2].set(arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream) + h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) if not partial: From 5ab5cf902b457740ac79d42ba359c3ca4b35b10e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:33:53 +0000 Subject: [PATCH 026/518] fixed typo --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 27e1412b..ad4ca3de 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -322,7 +322,7 @@ def _pobtas_streaming( L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) h2d_diagonal_events[1].record(stream=h2d_stream) - L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) + 
L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[1], stream=h2d_stream) h2d_lower_diagonal_events[1].record(stream=h2d_stream) From 069f355faa868d532b39e25d493ff791044bb49b Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:35:59 +0000 Subject: [PATCH 027/518] fixed typo --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index ad4ca3de..5972ea1c 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -410,7 +410,7 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_arrow_events[0]) compute_stream.wait_event(compute_partial_events[0]) - B_last_block_d -= (L_lower_arrow_blocks_d[-1] @ B_last_block_d[1]) + B_last_block_d -= (L_lower_arrow_blocks_d[1] @ B_last_block_d[1]) compute_partial_events[1].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[1]) From e15278663d7a811d3cb3f322ef0d5b57b9c9bd10 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:37:59 +0000 Subject: [PATCH 028/518] insert debug prints --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 5972ea1c..38dbc43a 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -410,6 +410,8 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_arrow_events[0]) compute_stream.wait_event(compute_partial_events[0]) + print(B_last_block_d.shape) + print(L-L_lower_arrow_blocks_d.shape) B_last_block_d -= (L_lower_arrow_blocks_d[1] @ B_last_block_d[1]) compute_partial_events[1].record(stream=compute_stream) From f932f67e68f2c9f3bab443053f5a0b9202312061 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:38:33 +0000 Subject: [PATCH 029/518] typo --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 38dbc43a..f0310332 100644 
--- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -411,7 +411,7 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_arrow_events[0]) compute_stream.wait_event(compute_partial_events[0]) print(B_last_block_d.shape) - print(L-L_lower_arrow_blocks_d.shape) + print(L_lower_arrow_blocks_d.shape) B_last_block_d -= (L_lower_arrow_blocks_d[1] @ B_last_block_d[1]) compute_partial_events[1].record(stream=compute_stream) From 730de95f43c3115f2618cadabe4b49e249e58450 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:40:09 +0000 Subject: [PATCH 030/518] changed b last block --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f0310332..55e089c8 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -412,7 +412,7 @@ def _pobtas_streaming( compute_stream.wait_event(compute_partial_events[0]) print(B_last_block_d.shape) print(L_lower_arrow_blocks_d.shape) - B_last_block_d -= (L_lower_arrow_blocks_d[1] @ B_last_block_d[1]) + B_last_block_d -= (L_lower_arrow_blocks_d[1] @ B_last_block_d) compute_partial_events[1].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[1]) From b61f8644df2b750ed666ffde0941fe5ed0775740 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:43:41 +0000 Subject: [PATCH 031/518] fixed lower arrow blocks in partial --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 55e089c8..dd2dd6a1 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -398,7 +398,7 @@ def _pobtas_streaming( L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) h2d_diagonal_events[0].record(stream=h2d_stream) - L_lower_arrow_blocks_d[0].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) + 
L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) h2d_arrow_events[0].record(stream=h2d_stream) From 81a063f610b6b76d51228b57d2d90a81157ef6bd Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:48:33 +0000 Subject: [PATCH 032/518] changed typo --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index dd2dd6a1..592000b4 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -416,7 +416,7 @@ def _pobtas_streaming( compute_partial_events[1].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[1]) - B_d[i % 2].get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_last_block_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) # Y_{ndb+1} = L_{ndb+1,ndb+1}^{-1} (B_{ndb+1} - \Sigma_{i=1}^{ndb} L_{ndb+1,i} Y_{i) From 51acd9aea4bbae5be647ccf9e9b38e013ae9bfe2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:49:49 +0000 Subject: [PATCH 033/518] changed test --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index cbb7aaed..b6bf238d 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -22,7 +22,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): - array_type = "streaming" + #array_type = "streaming" A = dd_bta( diagonal_blocksize, From 66a23f8389c0420de35995bd3dee995b3d89d0c6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 24 Apr 2025 15:52:25 +0000 Subject: [PATCH 034/518] new debug print --- src/serinv/algs/pobtas.py | 1 + tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 
592000b4..95bc5552 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -412,6 +412,7 @@ def _pobtas_streaming( compute_stream.wait_event(compute_partial_events[0]) print(B_last_block_d.shape) print(L_lower_arrow_blocks_d.shape) + print(L_lower_arrow_blocks_d[1] @ B_last_block_d) B_last_block_d -= (L_lower_arrow_blocks_d[1] @ B_last_block_d) compute_partial_events[1].record(stream=compute_stream) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index b6bf238d..cbb7aaed 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -22,7 +22,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): - #array_type = "streaming" + array_type = "streaming" A = dd_bta( diagonal_blocksize, From 7a456a58ca15f606a61c54f998ecaa1a43e61ef8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:26:44 +0000 Subject: [PATCH 035/518] changed logic to accomodate arrow sizes --- src/serinv/algs/pobtas.py | 84 ++++++++++++++++++++++++++++++++------- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 95bc5552..f289ced2 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -250,8 +250,10 @@ def _pobtas_streaming( h2d_lower_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] h2d_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] h2d_B_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_tip_events = [cp.cuda.Event(), cp.cuda.Event()] d2h_B_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_tip_events = [cp.cuda.Event(), cp.cuda.Event()] compute_current_B_events = [cp.cuda.Event(), cp.cuda.Event()] compute_next_B_events = [cp.cuda.Event(), cp.cuda.Event()] @@ -271,7 +273,7 @@ def _pobtas_streaming( (2, *B_shape.shape), dtype=B_shape.dtype ) B_shape = B[-arrow_blocksize:] - B_last_block_d = cp.empty_like(B_shape) + B_arrow_tip_d = 
cp.empty_like(B_shape) del B_shape # L Buffers @@ -292,7 +294,7 @@ def _pobtas_streaming( compute_next_B_events[1].record(stream=compute_stream) compute_arrow_B_events[1].record(stream=compute_stream) - B_last_block_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:, :], stream=h2d_stream) # --- H2D: transfers --- @@ -310,6 +312,7 @@ def _pobtas_streaming( # --- D2H: event --- d2h_B_events[1].record(stream=d2h_stream) + n_diag_blocks: int = L_diagonal_blocks.shape[0] # why? if n_diag_blocks > 1: @@ -345,14 +348,18 @@ def _pobtas_streaming( h2d_stream.wait_event(compute_current_B_events[i % 2]) L_diagonal_blocks_d[(i + 2) % 2].set(arr=L_diagonal_blocks[i + 2], stream=h2d_stream) h2d_diagonal_events[(i + 2) % 2].record(stream=h2d_stream) + if not ((i + 2) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): + B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream,) - d2h_stream.wait_event(compute_next_B_events[i % 2]) + d2h_stream.wait_event(compute_current_B_events[i % 2]) B_d[i % 2].get( out=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=d2h_stream, blocking=False, ) d2h_B_events[i % 2].record(stream=d2h_stream) + + with compute_stream: # 2 @@ -370,22 +377,54 @@ def _pobtas_streaming( h2d_stream.wait_event(compute_next_B_events[i % 2]) L_lower_diagonal_blocks_d[(i + 2) % 2].set(arr=L_lower_diagonal_blocks[i + 2], stream=h2d_stream) h2d_lower_diagonal_events[(i + 2) % 2].record(stream=h2d_stream) + + if not ((i + 2) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): + d2h_stream.wait_event(compute_next_B_events[i % 2]) + B_d[(i + 1) % 2].get( + out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + d2h_B_events[(i + 1) % 2].record(stream=d2h_stream) + + h2d_stream.wait_event(d2h_B_events[(i + 1) % 2]) + 
B_arrow_tip_d.set(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + h2d_tip_events[i % 2].record(stream=h2d_stream) + with compute_stream: # 3 compute_stream.wait_event(h2d_arrow_events[i % 2]) compute_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) compute_stream.wait_event(compute_next_B_events[i % 2]) - B_last_block_d -= ( + if not ((i + 2) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): + compute_stream.wait_event(h2d_tip_events[i % 2]) + + B_arrow_tip_d -= ( L_lower_arrow_blocks_d[i % 2] @ B_d[i % 2] ) + compute_arrow_B_events[i % 2].record(stream=compute_stream) + - if i + 2 < n_diag_blocks - 1: + # make sure that arrowtip and B overlap gets resolved + if ((i + 3) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): h2d_stream.wait_event(compute_arrow_B_events[i % 2]) B_d[(i + 2) % 2].set(arr=B[(i + 2) * diag_blocksize : (i + 3) * diag_blocksize], stream = h2d_stream) - h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + h2d_B_events[(i + 2) % 2].record(stream=h2d_stream) + + else: + d2h_stream.wait_event(compute_arrow_B_events[i % 2]) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + d2h_tip_events[i % 2].record(stream=d2h_stream) + + if i + 1 < n_diag_blocks - 1: + B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream = h2d_stream) + h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + + + if i + 2 < n_diag_blocks - 1: L_lower_arrow_blocks_d[(i + 1) % 2].set(arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream) h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) @@ -395,31 +434,48 @@ def _pobtas_streaming( # In the case of the partial solve, we do not solve the last block and # arrow tip block of the RHS. 
+ h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) h2d_diagonal_events[0].record(stream=h2d_stream) + B_d[0].set(arr=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=h2d_stream,) + h2d_B_events[0].record(stream=h2d_stream) + L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) h2d_arrow_events[0].record(stream=h2d_stream) + with compute_stream: compute_stream.wait_event(h2d_diagonal_events[0]) - B_last_block_d = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[0], lower=True,)) + compute_stream.wait_event(h2d_B_events[0]) + B_d = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[0], lower=True,)) compute_partial_events[0].record(stream=compute_stream) + d2h_stream.wait_event(compute_partial_events[0]) + B_d[0].get(out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=d2h_stream, blocking=False,) + d2h_B_events[0].record(stream=d2h_stream) + + h2d_stream.wait_event(d2h_B_events[0]) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + h2d_tip_events[0].record(stream=h2d_stream) + + with compute_stream: compute_stream.wait_event(h2d_arrow_events[0]) + compute_stream.wait_event(h2d_tip_events[0]) compute_stream.wait_event(compute_partial_events[0]) - print(B_last_block_d.shape) - print(L_lower_arrow_blocks_d.shape) - print(L_lower_arrow_blocks_d[1] @ B_last_block_d) - B_last_block_d -= (L_lower_arrow_blocks_d[1] @ B_last_block_d) + + B_arrow_tip_d -= (L_lower_arrow_blocks_d[1] @ B_arrow_tip_d) compute_partial_events[1].record(stream=compute_stream) - d2h_stream.wait_event(compute_partial_events[1]) - B_last_block_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + compute_stream.wait_event(compute_partial_events[1]) + B_arrow_tip_d = cu_la.solve_triangular(L_arrow_tip_block_d, B_arrow_tip_d, lower=True) + 
compute_partial_events[0].record(stream=compute_stream) + + d2h_stream.wait_event(compute_partial_events[0]) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - # Y_{ndb+1} = L_{ndb+1,ndb+1}^{-1} (B_{ndb+1} - \Sigma_{i=1}^{ndb} L_{ndb+1,i} Y_{i) elif trans == "T" or trans == "C": # ----- Backward substitution ----- From b7fa179c9ba859daef544eb06f117dafadf17cd4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:29:18 +0000 Subject: [PATCH 036/518] typo --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f289ced2..1219e182 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -388,7 +388,7 @@ def _pobtas_streaming( d2h_B_events[(i + 1) % 2].record(stream=d2h_stream) h2d_stream.wait_event(d2h_B_events[(i + 1) % 2]) - B_arrow_tip_d.set(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) h2d_tip_events[i % 2].record(stream=h2d_stream) From 1d550cd6ff091b701606cd6908c6919c511d11be Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:29:53 +0000 Subject: [PATCH 037/518] fixed function --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 1219e182..df1d4814 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -388,7 +388,7 @@ def _pobtas_streaming( d2h_B_events[(i + 1) % 2].record(stream=d2h_stream) h2d_stream.wait_event(d2h_B_events[(i + 1) % 2]) - B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=d2h_stream,) h2d_tip_events[i % 2].record(stream=h2d_stream) From 8c221e880c0574e7efd15c85c5da1e67857b2901 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:32:25 +0000 Subject: [PATCH 
038/518] insert debug statements --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index df1d4814..33847c30 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -454,6 +454,8 @@ def _pobtas_streaming( compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) + print(B_d[0]) + print(B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize]) B_d[0].get(out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=d2h_stream, blocking=False,) d2h_B_events[0].record(stream=d2h_stream) From 380340e2374adc18320757329951eb31193ed9e6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:35:39 +0000 Subject: [PATCH 039/518] more debugging --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 33847c30..861952f8 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -439,6 +439,7 @@ def _pobtas_streaming( h2d_diagonal_events[0].record(stream=h2d_stream) B_d[0].set(arr=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=h2d_stream,) + print(B_d[0]) h2d_B_events[0].record(stream=h2d_stream) L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) From d41e4e916a786dfba86c69f702ba46e144b3dca3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:39:47 +0000 Subject: [PATCH 040/518] debugging second to last solve --- src/serinv/algs/pobtas.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 861952f8..b767b300 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -445,7 +445,18 @@ def _pobtas_streaming( L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) 
h2d_arrow_events[0].record(stream=h2d_stream) - + B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize] = ( + la.solve_triangular( + L_diagonal_blocks[n_diag_blocks - 1], + B[ + (n_diag_blocks - 1) + * diag_blocksize : n_diag_blocks + * diag_blocksize + ], + lower=True, + ) + ) + print(B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize]) with compute_stream: @@ -461,7 +472,7 @@ def _pobtas_streaming( d2h_B_events[0].record(stream=d2h_stream) h2d_stream.wait_event(d2h_B_events[0]) - B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) h2d_tip_events[0].record(stream=h2d_stream) with compute_stream: From 7f8ae975cbe78cc618443576f32205e38306b7e7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:41:53 +0000 Subject: [PATCH 041/518] fixed second to last solve --- src/serinv/algs/pobtas.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index b767b300..9b952100 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -445,24 +445,12 @@ def _pobtas_streaming( L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) h2d_arrow_events[0].record(stream=h2d_stream) - B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize] = ( - la.solve_triangular( - L_diagonal_blocks[n_diag_blocks - 1], - B[ - (n_diag_blocks - 1) - * diag_blocksize : n_diag_blocks - * diag_blocksize - ], - lower=True, - ) - ) - print(B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize]) with compute_stream: compute_stream.wait_event(h2d_diagonal_events[0]) compute_stream.wait_event(h2d_B_events[0]) - B_d = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[0], lower=True,)) + B_d[0] = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[0], lower=True,)) 
compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) From c2b52aaa6009167238355655671658d4b0834924 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:42:53 +0000 Subject: [PATCH 042/518] removed debugging statements --- src/serinv/algs/pobtas.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 9b952100..4d0bbf08 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -439,7 +439,6 @@ def _pobtas_streaming( h2d_diagonal_events[0].record(stream=h2d_stream) B_d[0].set(arr=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=h2d_stream,) - print(B_d[0]) h2d_B_events[0].record(stream=h2d_stream) L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) @@ -454,13 +453,11 @@ def _pobtas_streaming( compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) - print(B_d[0]) - print(B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize]) B_d[0].get(out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=d2h_stream, blocking=False,) d2h_B_events[0].record(stream=d2h_stream) h2d_stream.wait_event(d2h_B_events[0]) - B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=d2h_stream,) h2d_tip_events[0].record(stream=h2d_stream) with compute_stream: From 11f838ecca8400e2f101094aa99f0ba2fecbe989 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:52:28 +0000 Subject: [PATCH 043/518] insert debug statements --- src/serinv/algs/pobtas.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 4d0bbf08..c566faa0 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -465,6 +465,9 @@ def _pobtas_streaming( 
compute_stream.wait_event(h2d_tip_events[0]) compute_stream.wait_event(compute_partial_events[0]) + print(L_lower_arrow_blocks_d[1]) + print(B_arrow_tip_d) + B_arrow_tip_d -= (L_lower_arrow_blocks_d[1] @ B_arrow_tip_d) compute_partial_events[1].record(stream=compute_stream) From 133d0067825f14d9cadffb62a9b9b693950cf753 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:55:50 +0000 Subject: [PATCH 044/518] fixed index typo --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index c566faa0..9dc7620f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -465,10 +465,10 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_tip_events[0]) compute_stream.wait_event(compute_partial_events[0]) - print(L_lower_arrow_blocks_d[1]) + print(L_lower_arrow_blocks_d[0]) print(B_arrow_tip_d) - B_arrow_tip_d -= (L_lower_arrow_blocks_d[1] @ B_arrow_tip_d) + B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_arrow_tip_d) compute_partial_events[1].record(stream=compute_stream) compute_stream.wait_event(compute_partial_events[1]) From 2ab30872d952a3232e73adafc2674fce87540575 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 13:58:14 +0000 Subject: [PATCH 045/518] changed debug statement --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 9dc7620f..bce926ba 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -465,7 +465,7 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_tip_events[0]) compute_stream.wait_event(compute_partial_events[0]) - print(L_lower_arrow_blocks_d[0]) + print(L_lower_arrow_blocks_d) print(B_arrow_tip_d) B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_arrow_tip_d) From aeca9e1d9d0b71458abd857ec7be704bdca7e693 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 14:01:18 +0000 Subject: 
[PATCH 046/518] changed operation order --- src/serinv/algs/pobtas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bce926ba..f0ea1c06 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -468,11 +468,12 @@ def _pobtas_streaming( print(L_lower_arrow_blocks_d) print(B_arrow_tip_d) - B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_arrow_tip_d) + + B_arrow_tip_d = cu_la.solve_triangular(L_arrow_tip_block_d, B_arrow_tip_d, lower=True) compute_partial_events[1].record(stream=compute_stream) compute_stream.wait_event(compute_partial_events[1]) - B_arrow_tip_d = cu_la.solve_triangular(L_arrow_tip_block_d, B_arrow_tip_d, lower=True) + B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_arrow_tip_d) compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) From 78e8e25c38bd705b52eecf4a25291d53bead75ba Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 14:08:23 +0000 Subject: [PATCH 047/518] changed to right B --- src/serinv/algs/pobtas.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f0ea1c06..f9f40b15 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -468,12 +468,11 @@ def _pobtas_streaming( print(L_lower_arrow_blocks_d) print(B_arrow_tip_d) - - B_arrow_tip_d = cu_la.solve_triangular(L_arrow_tip_block_d, B_arrow_tip_d, lower=True) + B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_d[0]) compute_partial_events[1].record(stream=compute_stream) compute_stream.wait_event(compute_partial_events[1]) - B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_arrow_tip_d) + B_arrow_tip_d = cu_la.solve_triangular(L_arrow_tip_block_d, B_arrow_tip_d, lower=True) compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) From f8f5b6444d5a0e9b93a59bcffbadf561bfedc3fc Mon Sep 17 00:00:00 
2001 From: 03szust Date: Fri, 25 Apr 2025 14:13:52 +0000 Subject: [PATCH 048/518] setup corrected for out of bounds --- src/serinv/algs/pobtas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f9f40b15..5d940ff9 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -325,8 +325,9 @@ def _pobtas_streaming( L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) h2d_diagonal_events[1].record(stream=h2d_stream) - L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[1], stream=h2d_stream) - h2d_lower_diagonal_events[1].record(stream=h2d_stream) + if n_diag_blocks > 2: + L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[1], stream=h2d_stream) + h2d_lower_diagonal_events[1].record(stream=h2d_stream) if trans == "N": From 22c829ac7b3b29fff0909b81cffd06c674b6084e Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 14:15:02 +0000 Subject: [PATCH 049/518] removed debug statements --- src/serinv/algs/pobtas.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 5d940ff9..53a4f14a 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -466,9 +466,6 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_tip_events[0]) compute_stream.wait_event(compute_partial_events[0]) - print(L_lower_arrow_blocks_d) - print(B_arrow_tip_d) - B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_d[0]) compute_partial_events[1].record(stream=compute_stream) From 79d78f05c653bc00575733a795ad8c13fe7f8ecd Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 14:19:06 +0000 Subject: [PATCH 050/518] insert debug statement --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index cbb7aaed..af1daffc 100644 
--- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -22,7 +22,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): - array_type = "streaming" + # array_type = "streaming" A = dd_bta( diagonal_blocksize, @@ -57,6 +57,7 @@ def test_pobtas( ) = bta_dense_to_arrays(A, diagonal_blocksize, arrowhead_blocksize, n_diag_blocks) if backend_flags["cupy_avail"] and array_type == "streaming": + print("streaming") A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks) A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :] A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks) From f00a04a4c5ac9086b3bdaebe36e9424ee38bb7d5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 14:28:46 +0000 Subject: [PATCH 051/518] forced streaming in tests again --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index af1daffc..cbb7aaed 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -22,7 +22,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): - # array_type = "streaming" + array_type = "streaming" A = dd_bta( diagonal_blocksize, @@ -57,7 +57,6 @@ def test_pobtas( ) = bta_dense_to_arrays(A, diagonal_blocksize, arrowhead_blocksize, n_diag_blocks) if backend_flags["cupy_avail"] and array_type == "streaming": - print("streaming") A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks) A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :] A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks) From cccd82d45bb6ae00ef866fc49252b9494db56712 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 15:23:36 +0000 Subject: [PATCH 052/518] force streaming in pobtaf for testing --- 
tests/tests_algs/regular/tests_bta/test_pobtaf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index a30b9094..ab2e306c 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -20,6 +20,8 @@ def test_pobtaf( array_type: str, dtype: np.dtype, ): + array_type = "streaming" + A = dd_bta( diagonal_blocksize, arrowhead_blocksize, From 416b0aabece8bf3447eaa17545689b0d1ca55ae5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 25 Apr 2025 15:26:05 +0000 Subject: [PATCH 053/518] removed forced streaming from pobtaf --- tests/tests_algs/regular/tests_bta/test_pobtaf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index ab2e306c..a30b9094 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -20,8 +20,6 @@ def test_pobtaf( array_type: str, dtype: np.dtype, ): - array_type = "streaming" - A = dd_bta( diagonal_blocksize, arrowhead_blocksize, From e943999f957ebbf89eb2bcf45b983999b097f749 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 07:31:17 +0000 Subject: [PATCH 054/518] changed stream timing --- src/serinv/algs/pobtas.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 53a4f14a..478b3765 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -336,8 +336,6 @@ def _pobtas_streaming( with compute_stream: # Compute step 1 : compute B compute_stream.wait_event(h2d_diagonal_events[i % 2]) - compute_stream.wait_event(compute_arrow_B_events[i % 2]) - compute_stream.wait_event(compute_current_B_events[(i + 1) % 2]) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2], @@ -364,10 +362,7 @@ def _pobtas_streaming( 
with compute_stream: # 2 - compute_stream.wait_event(h2d_lower_diagonal_events[i % 2]) compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) - compute_stream.wait_event(compute_current_B_events[i % 2]) - compute_stream.wait_event(compute_next_B_events[(i + 1) % 2]) B_d[(i + 1) % 2] -= ( L_lower_diagonal_blocks_d[i % 2] @ B_d[i % 2] @@ -396,8 +391,6 @@ def _pobtas_streaming( with compute_stream: # 3 compute_stream.wait_event(h2d_arrow_events[i % 2]) - compute_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) - compute_stream.wait_event(compute_next_B_events[i % 2]) if not ((i + 2) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): compute_stream.wait_event(h2d_tip_events[i % 2]) @@ -448,7 +441,6 @@ def _pobtas_streaming( with compute_stream: - compute_stream.wait_event(h2d_diagonal_events[0]) compute_stream.wait_event(h2d_B_events[0]) B_d[0] = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[0], lower=True,)) compute_partial_events[0].record(stream=compute_stream) @@ -462,9 +454,7 @@ def _pobtas_streaming( h2d_tip_events[0].record(stream=h2d_stream) with compute_stream: - compute_stream.wait_event(h2d_arrow_events[0]) compute_stream.wait_event(h2d_tip_events[0]) - compute_stream.wait_event(compute_partial_events[0]) B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_d[0]) compute_partial_events[1].record(stream=compute_stream) From 8a05718b5f38085deb21c8c08ccaf9857f170639 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 07:32:54 +0000 Subject: [PATCH 055/518] added sync --- src/serinv/algs/pobtas.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 478b3765..bba93898 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -480,4 +480,7 @@ def _pobtas_streaming( # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} else: - raise ValueError(f"Invalid transpose argument: {trans}.") \ No newline at end of file 
+ raise ValueError(f"Invalid transpose argument: {trans}.") + + + cp.cuda.Device().synchronize() \ No newline at end of file From cef552cad68e416e73c101e0b9fce7c1035c858d Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 11:31:55 +0000 Subject: [PATCH 056/518] insert debug statements --- src/serinv/algs/pobtas.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bba93898..46f201c1 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -331,6 +331,10 @@ def _pobtas_streaming( if trans == "N": + + print(B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize]) + print(B[-arrow_blocksize:]) + for i in range(0, n_diag_blocks - 1): # --- Forward substitution --- with compute_stream: From 18a8b8f0015d2e8fe48d0fd02b3b80c1cf8423a8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 11:33:41 +0000 Subject: [PATCH 057/518] insert antoher debug statement --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 46f201c1..40142a77 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -332,6 +332,7 @@ def _pobtas_streaming( if trans == "N": + print(B) print(B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize]) print(B[-arrow_blocksize:]) From 8a1f9f3fbdaf79b4b64aa95c93b69f13fdff0fef Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 11:40:14 +0000 Subject: [PATCH 058/518] removed misguided overlap protection --- src/serinv/algs/pobtas.py | 45 ++++++--------------------------------- 1 file changed, 7 insertions(+), 38 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 40142a77..ab9df1ea 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -331,11 +331,6 @@ def _pobtas_streaming( if trans == "N": - - print(B) - print(B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * 
diag_blocksize]) - print(B[-arrow_blocksize:]) - for i in range(0, n_diag_blocks - 1): # --- Forward substitution --- with compute_stream: @@ -352,8 +347,6 @@ def _pobtas_streaming( h2d_stream.wait_event(compute_current_B_events[i % 2]) L_diagonal_blocks_d[(i + 2) % 2].set(arr=L_diagonal_blocks[i + 2], stream=h2d_stream) h2d_diagonal_events[(i + 2) % 2].record(stream=h2d_stream) - if not ((i + 2) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): - B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream,) d2h_stream.wait_event(compute_current_B_events[i % 2]) B_d[i % 2].get( @@ -378,26 +371,10 @@ def _pobtas_streaming( h2d_stream.wait_event(compute_next_B_events[i % 2]) L_lower_diagonal_blocks_d[(i + 2) % 2].set(arr=L_lower_diagonal_blocks[i + 2], stream=h2d_stream) h2d_lower_diagonal_events[(i + 2) % 2].record(stream=h2d_stream) - - if not ((i + 2) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): - d2h_stream.wait_event(compute_next_B_events[i % 2]) - B_d[(i + 1) % 2].get( - out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], - stream=d2h_stream, - blocking=False, - ) - d2h_B_events[(i + 1) % 2].record(stream=d2h_stream) - - h2d_stream.wait_event(d2h_B_events[(i + 1) % 2]) - B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=d2h_stream,) - h2d_tip_events[i % 2].record(stream=h2d_stream) - with compute_stream: # 3 compute_stream.wait_event(h2d_arrow_events[i % 2]) - if not ((i + 2) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): - compute_stream.wait_event(h2d_tip_events[i % 2]) B_arrow_tip_d -= ( L_lower_arrow_blocks_d[i % 2] @@ -405,22 +382,14 @@ def _pobtas_streaming( ) compute_arrow_B_events[i % 2].record(stream=compute_stream) - - - # make sure that arrowtip and B overlap gets resolved - if ((i + 3) * diag_blocksize) < (n_diag_blocks * diag_blocksize - arrow_blocksize): - h2d_stream.wait_event(compute_arrow_B_events[i % 2]) - B_d[(i + 2) % 
2].set(arr=B[(i + 2) * diag_blocksize : (i + 3) * diag_blocksize], stream = h2d_stream) - h2d_B_events[(i + 2) % 2].record(stream=h2d_stream) - else: - d2h_stream.wait_event(compute_arrow_B_events[i % 2]) - B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - d2h_tip_events[i % 2].record(stream=d2h_stream) - - if i + 1 < n_diag_blocks - 1: - B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream = h2d_stream) - h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + d2h_stream.wait_event(compute_arrow_B_events[i % 2]) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + d2h_tip_events[i % 2].record(stream=d2h_stream) + + if i + 1 < n_diag_blocks - 1: + B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream = h2d_stream) + h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) if i + 2 < n_diag_blocks - 1: From 5acc905905965d1619ee4816b52a1c1034d84e53 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 12:31:38 +0000 Subject: [PATCH 059/518] changed streaming order --- src/serinv/algs/pobtas.py | 96 +++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 39 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index ab9df1ea..a2f6799a 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -304,76 +304,104 @@ def _pobtas_streaming( L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) h2d_diagonal_events[0].record(stream=h2d_stream) - #L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) - #h2d_lower_diagonal_events[0].record(stream=h2d_stream) - L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[0], stream=h2d_stream) h2d_arrow_events[0].record(stream=h2d_stream) # --- D2H: event --- d2h_B_events[1].record(stream=d2h_stream) + n_diag_blocks: int = L_diagonal_blocks.shape[0] - n_diag_blocks: int = L_diagonal_blocks.shape[0] 
# why? - if n_diag_blocks > 1: - B_d[1].set(arr=B[1 * diag_blocksize : 2 * diag_blocksize], stream = h2d_stream) - h2d_B_events[1].record(stream=h2d_stream) - - L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) - h2d_lower_diagonal_events[0].record(stream=h2d_stream) - - L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) - h2d_diagonal_events[1].record(stream=h2d_stream) + # if n_diag_blocks > 1: - if n_diag_blocks > 2: - L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[1], stream=h2d_stream) - h2d_lower_diagonal_events[1].record(stream=h2d_stream) + L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + h2d_lower_diagonal_events[0].record(stream=h2d_stream) if trans == "N": for i in range(0, n_diag_blocks - 1): # --- Forward substitution --- + + if i + 1 < n_diag_blocks - 1: + # stream next B block + h2d_stream.wait_event(d2h_B_events[(i + 1) % 2]) + + B_d[(i + 1) % 2].set( + arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream = h2d_stream + ) + + h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + + # stream next diagonal block + h2d_stream.wait_event(compute_current_B_events[(i + 1) % 2]) + + L_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_diagonal_blocks[i + 1], + stream=h2d_stream + ) + + h2d_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) + + with compute_stream: # Compute step 1 : compute B compute_stream.wait_event(h2d_diagonal_events[i % 2]) + B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2], lower=True, ) + compute_current_B_events[i % 2].record(stream=compute_stream) - if i + 2 < n_diag_blocks - 1: - h2d_stream.wait_event(compute_current_B_events[i % 2]) - L_diagonal_blocks_d[(i + 2) % 2].set(arr=L_diagonal_blocks[i + 2], stream=h2d_stream) - h2d_diagonal_events[(i + 2) % 2].record(stream=h2d_stream) - + # stream B back d2h_stream.wait_event(compute_current_B_events[i % 2]) + B_d[i % 2].get( out=B[i * 
diag_blocksize : (i + 1) * diag_blocksize], stream=d2h_stream, blocking=False, ) + d2h_B_events[i % 2].record(stream=d2h_stream) - + if i + 1 < n_diag_blocks - 1: + # stream next lower diagonal block + h2d_stream.wait_event(compute_next_B_events[(i + 1) % 2]) + + L_lower_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_lower_diagonal_blocks[i + 1], + stream=h2d_stream + ) + + h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: - # 2 + # Compute step 2 : update next B compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) + B_d[(i + 1) % 2] -= ( L_lower_diagonal_blocks_d[i % 2] @ B_d[i % 2] ) - compute_next_B_events[i % 2].record(stream=compute_stream) - if i + 2 < n_diag_blocks - 1: - h2d_stream.wait_event(compute_next_B_events[i % 2]) - L_lower_diagonal_blocks_d[(i + 2) % 2].set(arr=L_lower_diagonal_blocks[i + 2], stream=h2d_stream) - h2d_lower_diagonal_events[(i + 2) % 2].record(stream=h2d_stream) + compute_next_B_events[i % 2].record(stream=compute_stream) + if i + 1 < n_diag_blocks - 1: + # stream next lower arrow block + h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + + L_lower_arrow_blocks_d[(i + 1) % 2].set( + arr=L_lower_arrow_blocks[i + 1], + stream=h2d_stream + ) + + h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) + with compute_stream: - # 3 + # Compute step 3 : update arrowtip compute_stream.wait_event(h2d_arrow_events[i % 2]) B_arrow_tip_d -= ( @@ -386,16 +414,6 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_arrow_B_events[i % 2]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) d2h_tip_events[i % 2].record(stream=d2h_stream) - - if i + 1 < n_diag_blocks - 1: - B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream = h2d_stream) - h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) - - - if i + 2 < n_diag_blocks - 1: - - L_lower_arrow_blocks_d[(i + 1) % 2].set(arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream) - h2d_arrow_events[(i 
+ 1) % 2].record(stream=h2d_stream) if not partial: From c45adc911ce8b5bd94c93279074d26cf5534497b Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 12:32:25 +0000 Subject: [PATCH 060/518] rolled back if statement --- src/serinv/algs/pobtas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index a2f6799a..7d59cb0f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -312,10 +312,10 @@ def _pobtas_streaming( n_diag_blocks: int = L_diagonal_blocks.shape[0] - # if n_diag_blocks > 1: + if n_diag_blocks > 1: - L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) - h2d_lower_diagonal_events[0].record(stream=h2d_stream) + L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + h2d_lower_diagonal_events[0].record(stream=h2d_stream) if trans == "N": From 37518945528c512ff98222231fea9ffdcdda7862 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 12:37:38 +0000 Subject: [PATCH 061/518] debug statement to check if the last block is the problem --- src/serinv/algs/pobtas.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 7d59cb0f..0384545d 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -319,8 +319,8 @@ def _pobtas_streaming( if trans == "N": + # --- Forward substitution --- for i in range(0, n_diag_blocks - 1): - # --- Forward substitution --- if i + 1 < n_diag_blocks - 1: # stream next B block @@ -419,6 +419,11 @@ def _pobtas_streaming( if not partial: # In the case of the partial solve, we do not solve the last block and # arrow tip block of the RHS. + + raise NotImplementedError( + "wrong." 
+ ) + h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) @@ -464,8 +469,8 @@ def _pobtas_streaming( if not partial: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) raise NotImplementedError( - "T and C not yet implemented." - ) + "T and C not yet implemented." + ) # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) # for i in range(n_diag_blocks -2, -1, -1): From c6f63d1c9b261bb1c79d39849d211eafe6966f00 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 12:48:24 +0000 Subject: [PATCH 062/518] changed non partial solve --- src/serinv/algs/pobtas.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 0384545d..dde50341 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -419,41 +419,29 @@ def _pobtas_streaming( if not partial: # In the case of the partial solve, we do not solve the last block and # arrow tip block of the RHS. - - raise NotImplementedError( - "wrong." 
- ) - h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) h2d_diagonal_events[0].record(stream=h2d_stream) - B_d[0].set(arr=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=h2d_stream,) - h2d_B_events[0].record(stream=h2d_stream) - L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) h2d_arrow_events[0].record(stream=h2d_stream) with compute_stream: - compute_stream.wait_event(h2d_B_events[0]) - B_d[0] = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[0], lower=True,)) + compute_stream.wait_event(h2d_diagonal_events[0]) + B_d[(n_diag_blocks - 1) % 2] = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[(n_diag_blocks - 1) % 2], lower=True,)) compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) B_d[0].get(out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=d2h_stream, blocking=False,) d2h_B_events[0].record(stream=d2h_stream) - h2d_stream.wait_event(d2h_B_events[0]) - B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=d2h_stream,) - h2d_tip_events[0].record(stream=h2d_stream) - with compute_stream: - compute_stream.wait_event(h2d_tip_events[0]) + compute_stream.wait_event(h2d_arrow_events[0]) - B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_d[0]) + B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_d[(n_diag_blocks - 1) % 2]) compute_partial_events[1].record(stream=compute_stream) compute_stream.wait_event(compute_partial_events[1]) From ee1798c16e7f1bd73e74c9089389e59a8c4d692b Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 12:52:12 +0000 Subject: [PATCH 063/518] debug to see passed tests --- src/serinv/algs/pobtas.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index dde50341..7865616b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py 
@@ -417,6 +417,9 @@ def _pobtas_streaming( if not partial: + raise NotImplementedError( + "just error display" + ) # In the case of the partial solve, we do not solve the last block and # arrow tip block of the RHS. From c63a5de2dc4b9f0d463dad62a67d976e87fe025c Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 12:57:35 +0000 Subject: [PATCH 064/518] inserted debug statements to compare B --- src/serinv/algs/pobtas.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 7865616b..e89c9273 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -319,6 +319,7 @@ def _pobtas_streaming( if trans == "N": + print(B) # --- Forward substitution --- for i in range(0, n_diag_blocks - 1): @@ -401,7 +402,7 @@ def _pobtas_streaming( h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: - # Compute step 3 : update arrowtip + # Compute step 3 : update arrow tip compute_stream.wait_event(h2d_arrow_events[i % 2]) B_arrow_tip_d -= ( @@ -417,9 +418,6 @@ def _pobtas_streaming( if not partial: - raise NotImplementedError( - "just error display" - ) # In the case of the partial solve, we do not solve the last block and # arrow tip block of the RHS. 
@@ -454,6 +452,8 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_partial_events[0]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + print(B) + elif trans == "T" or trans == "C": # ----- Backward substitution ----- From 2083e06c77ec523c0217344af1fe6d798122269b Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 13:02:06 +0000 Subject: [PATCH 065/518] changed arrow tip block --- src/serinv/algs/pobtas.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index e89c9273..1543d657 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -295,7 +295,7 @@ def _pobtas_streaming( compute_arrow_B_events[1].record(stream=compute_stream) B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) - L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:, :], stream=h2d_stream) + L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) # --- H2D: transfers --- B_d[0].set(arr=B[0 : 1 * diag_blocksize], stream = h2d_stream) @@ -319,7 +319,6 @@ def _pobtas_streaming( if trans == "N": - print(B) # --- Forward substitution --- for i in range(0, n_diag_blocks - 1): @@ -452,9 +451,6 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_partial_events[0]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - print(B) - - elif trans == "T" or trans == "C": # ----- Backward substitution ----- if not partial: From 7805321b36b6c50e73ada04ae28e65c0ed65842e Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 13:23:38 +0000 Subject: [PATCH 066/518] changed stream timing --- src/serinv/algs/pobtas.py | 134 +++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 1543d657..96be4853 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -241,90 +241,90 @@ def _pobtas_streaming( cp, 
cu_la = _get_module_from_str(module_str="cupy") - # Streams and events - compute_stream = cp.cuda.Stream(non_blocking=True) - h2d_stream = cp.cuda.Stream(non_blocking=True) - d2h_stream = cp.cuda.Stream(non_blocking=True) + if trans == "N": - h2d_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] - h2d_lower_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] - h2d_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] - h2d_B_events = [cp.cuda.Event(), cp.cuda.Event()] - h2d_tip_events = [cp.cuda.Event(), cp.cuda.Event()] + # Streams and events + compute_stream = cp.cuda.Stream(non_blocking=True) + h2d_stream = cp.cuda.Stream(non_blocking=True) + d2h_stream = cp.cuda.Stream(non_blocking=True) - d2h_B_events = [cp.cuda.Event(), cp.cuda.Event()] - d2h_tip_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_lower_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_B_events = [cp.cuda.Event(), cp.cuda.Event()] - compute_current_B_events = [cp.cuda.Event(), cp.cuda.Event()] - compute_next_B_events = [cp.cuda.Event(), cp.cuda.Event()] - compute_arrow_B_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_B_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_tip_events = [cp.cuda.Event(), cp.cuda.Event()] - compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_current_B_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_next_B_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_arrow_B_events = [cp.cuda.Event(), cp.cuda.Event()] - # Vars - diag_blocksize = L_diagonal_blocks.shape[1] - arrow_blocksize = L_lower_arrow_blocks.shape[1] - n_diag_blocks = L_diagonal_blocks.shape[0] + compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] - # Device Buffers - # B Buffers - B_shape = B[0 : diag_blocksize] # block template - B_d = cp.empty( - (2, *B_shape.shape), dtype=B_shape.dtype - ) - B_shape = B[-arrow_blocksize:] - B_arrow_tip_d = 
cp.empty_like(B_shape) - del B_shape - - # L Buffers - L_diagonal_blocks_d = cp.empty( - (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype - ) - L_lower_diagonal_blocks_d = cp.empty( - (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype - ) - L_lower_arrow_blocks_d = cp.empty( - (2, *L_lower_arrow_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype - ) - L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) - - # Forward Pass - # --- C: events + transfers --- - compute_current_B_events[1].record(stream=compute_stream) - compute_next_B_events[1].record(stream=compute_stream) - compute_arrow_B_events[1].record(stream=compute_stream) - - B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) - L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) - - # --- H2D: transfers --- - B_d[0].set(arr=B[0 : 1 * diag_blocksize], stream = h2d_stream) - h2d_B_events[0].record(stream=h2d_stream) - - L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) - h2d_diagonal_events[0].record(stream=h2d_stream) + # Vars + diag_blocksize = L_diagonal_blocks.shape[1] + arrow_blocksize = L_lower_arrow_blocks.shape[1] + n_diag_blocks = L_diagonal_blocks.shape[0] - L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[0], stream=h2d_stream) - h2d_arrow_events[0].record(stream=h2d_stream) + # Device Buffers + # B Buffers + B_shape = B[0 : diag_blocksize] # block template + B_d = cp.empty( + (2, *B_shape.shape), dtype=B_shape.dtype + ) + B_shape = B[-arrow_blocksize:] + B_arrow_tip_d = cp.empty_like(B_shape) + del B_shape - # --- D2H: event --- - d2h_B_events[1].record(stream=d2h_stream) - - n_diag_blocks: int = L_diagonal_blocks.shape[0] + # L Buffers + L_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_arrow_blocks_d = cp.empty( + (2, 
*L_lower_arrow_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) - if n_diag_blocks > 1: + # Forward Pass + # --- C: events + transfers --- + compute_current_B_events[1].record(stream=compute_stream) + compute_next_B_events[1].record(stream=compute_stream) + compute_arrow_B_events[1].record(stream=compute_stream) - L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) - h2d_lower_diagonal_events[0].record(stream=h2d_stream) + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) + L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) + # --- H2D: transfers --- + B_d[0].set(arr=B[0 : diag_blocksize], stream = h2d_stream) + h2d_B_events[0].record(stream=h2d_stream) + + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) + h2d_diagonal_events[0].record(stream=h2d_stream) - if trans == "N": + L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[0], stream=h2d_stream) + h2d_arrow_events[0].record(stream=h2d_stream) + + # --- D2H: event --- + d2h_B_events[1].record(stream=d2h_stream) + + n_diag_blocks: int = L_diagonal_blocks.shape[0] + + if n_diag_blocks > 1: + + L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + h2d_lower_diagonal_events[0].record(stream=h2d_stream) + + # --- Forward substitution --- for i in range(0, n_diag_blocks - 1): if i + 1 < n_diag_blocks - 1: # stream next B block - h2d_stream.wait_event(d2h_B_events[(i + 1) % 2]) + h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) B_d[(i + 1) % 2].set( arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], From 12157121a274b395324ee6290245817a4930505e Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 13:37:59 +0000 Subject: [PATCH 067/518] changed if to stream b + 1 --- src/serinv/algs/pobtas.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py 
index 96be4853..50156d13 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -322,7 +322,7 @@ def _pobtas_streaming( # --- Forward substitution --- for i in range(0, n_diag_blocks - 1): - if i + 1 < n_diag_blocks - 1: + if i < n_diag_blocks - 1: # stream next B block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) @@ -333,6 +333,7 @@ def _pobtas_streaming( h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + if i + 1 < n_diag_blocks - 1: # stream next diagonal block h2d_stream.wait_event(compute_current_B_events[(i + 1) % 2]) From b509c0ddc784138733d16d900b1db324cd7076f8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 13:40:16 +0000 Subject: [PATCH 068/518] debug changed to check n --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 50156d13..23549fad 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -322,7 +322,7 @@ def _pobtas_streaming( # --- Forward substitution --- for i in range(0, n_diag_blocks - 1): - if i < n_diag_blocks - 1: + if i + 1 < n_diag_blocks - 1: # stream next B block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) From c0d0c330ed17cfe667c2c7e2dac436c7911cdfcc Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 13:40:58 +0000 Subject: [PATCH 069/518] consitentcy update --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 23549fad..6dd1bfdd 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -322,7 +322,7 @@ def _pobtas_streaming( # --- Forward substitution --- for i in range(0, n_diag_blocks - 1): - if i + 1 < n_diag_blocks - 1: + if i + 1 < n_diag_blocks: # stream next B block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) From 6c45c15b198f07182e58423a22270ee56e23f69b Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 
29 Apr 2025 13:50:04 +0000 Subject: [PATCH 070/518] changed non partial part --- src/serinv/algs/pobtas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 6dd1bfdd..ccb34f20 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -422,8 +422,8 @@ def _pobtas_streaming( # arrow tip block of the RHS. h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) - L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) - h2d_diagonal_events[0].record(stream=h2d_stream) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) + h2d_diagonal_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) h2d_arrow_events[0].record(stream=h2d_stream) @@ -432,7 +432,7 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[0]) - B_d[(n_diag_blocks - 1) % 2] = (cu_la.solve_triangular(L_diagonal_blocks_d[0], B_d[(n_diag_blocks - 1) % 2], lower=True,)) + B_d[(n_diag_blocks - 1) % 2] = (cu_la.solve_triangular(L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2], lower=True,)) compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) From ba9f28ead0b50feb26446bf1494526220d45f196 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 14:01:17 +0000 Subject: [PATCH 071/518] changed non partial block to match indexing --- src/serinv/algs/pobtas.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index ccb34f20..33979a92 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -425,24 +425,24 @@ def _pobtas_streaming( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[n_diag_blocks - 1], 
stream=h2d_stream) h2d_diagonal_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) - L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) - h2d_arrow_events[0].record(stream=h2d_stream) + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) + h2d_arrow_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) with compute_stream: - compute_stream.wait_event(h2d_diagonal_events[0]) + compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) B_d[(n_diag_blocks - 1) % 2] = (cu_la.solve_triangular(L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2], lower=True,)) compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) - B_d[0].get(out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=d2h_stream, blocking=False,) + B_d[(n_diag_blocks - 1) % 2].get(out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=d2h_stream, blocking=False,) d2h_B_events[0].record(stream=d2h_stream) with compute_stream: - compute_stream.wait_event(h2d_arrow_events[0]) + compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) - B_arrow_tip_d -= (L_lower_arrow_blocks_d[0] @ B_d[(n_diag_blocks - 1) % 2]) + B_arrow_tip_d -= (L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2] @ B_d[(n_diag_blocks - 1) % 2]) compute_partial_events[1].record(stream=compute_stream) compute_stream.wait_event(compute_partial_events[1]) From 90d6a747dd4c6cd285c7f6aeb17945333181f1cc Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 20:40:34 +0000 Subject: [PATCH 072/518] first attempt at backward solve --- src/serinv/algs/pobtas.py | 157 +++++++++++++----- .../regular/tests_bta/test_pobtas.py | 1 + 2 files changed, 121 insertions(+), 37 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 33979a92..6da26720 100644 --- a/src/serinv/algs/pobtas.py +++ 
b/src/serinv/algs/pobtas.py @@ -241,13 +241,47 @@ def _pobtas_streaming( cp, cu_la = _get_module_from_str(module_str="cupy") + # Vars + diag_blocksize = L_diagonal_blocks.shape[1] + arrow_blocksize = L_lower_arrow_blocks.shape[1] + n_diag_blocks = L_diagonal_blocks.shape[0] + + # Streams + compute_stream = cp.cuda.Stream(non_blocking=True) + h2d_stream = cp.cuda.Stream(non_blocking=True) + d2h_stream = cp.cuda.Stream(non_blocking=True) + + + + # Device Buffers + # B Buffers + B_shape = B[-arrow_blocksize:] # block template + B_arrow_tip_d = cp.empty_like(B_shape) + + B_shape = B[0 : diag_blocksize] + B_d = cp.empty( + (2, *B_shape.shape), dtype=B_shape.dtype + ) + + + # L Buffers + L_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_arrow_blocks_d = cp.empty( + (2, *L_lower_arrow_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) + if trans == "N": - # Streams and events - compute_stream = cp.cuda.Stream(non_blocking=True) - h2d_stream = cp.cuda.Stream(non_blocking=True) - d2h_stream = cp.cuda.Stream(non_blocking=True) + # delete helper variable + del B_shape + # Events h2d_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] h2d_lower_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] h2d_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] @@ -262,33 +296,6 @@ def _pobtas_streaming( compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] - # Vars - diag_blocksize = L_diagonal_blocks.shape[1] - arrow_blocksize = L_lower_arrow_blocks.shape[1] - n_diag_blocks = L_diagonal_blocks.shape[0] - - # Device Buffers - # B Buffers - B_shape = B[0 : diag_blocksize] # block template - B_d = cp.empty( - (2, *B_shape.shape), dtype=B_shape.dtype - ) - B_shape = B[-arrow_blocksize:] - B_arrow_tip_d = cp.empty_like(B_shape) - del B_shape - - # L 
Buffers - L_diagonal_blocks_d = cp.empty( - (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype - ) - L_lower_diagonal_blocks_d = cp.empty( - (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype - ) - L_lower_arrow_blocks_d = cp.empty( - (2, *L_lower_arrow_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype - ) - L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) - # Forward Pass # --- C: events + transfers --- compute_current_B_events[1].record(stream=compute_stream) @@ -453,16 +460,92 @@ def _pobtas_streaming( B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) elif trans == "T" or trans == "C": + # Buffers + B_previous_d = cp.empty_like(B_shape) + del B_shape + + # Events + compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_events = [cp.cuda.Event(), cp.cuda.Event()] + + # Forward Pass + # --- C: events + transfers --- + + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) + L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) + B_d[(n_diag_blocks - 1) % 2].set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) + + h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) + if n_diag_blocks > 1: + B_d[n_diag_blocks % 2].set( + arr=B[-arrow_blocksize - 2 * diag_blocksize : -arrow_blocksize - diag_blocksize], + stream=h2d_stream + ) + # ----- Backward substitution ----- if not partial: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) - raise NotImplementedError( - "T and C not yet implemented." 
- ) - # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) + with compute_stream: + compute_stream.wait_event(h2d_events[n_diag_blocks % 2]) + B_arrow_tip_d = cu_la.solve_triangular( + L_arrow_tip_block_d, + B_arrow_tip_d, + lower=True, + trans="C", + ) - # for i in range(n_diag_blocks -2, -1, -1): - # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + B_d[(n_diag_blocks -1) % 2] = ( + cu_la.solve_triangular( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2] + - L_lower_arrow_blocks[(n_diag_blocks - 1) % 2].conj().T @ B_arrow_tip_d, + lower=True, + trans="C", + ) + ) + + compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + + d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_d.get(out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=d2h_stream, blocking=False,) + d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) + + for i in range(n_diag_blocks - 2, -1, -1): + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + + if i > 0: + with h2d_stream: + h2d_stream.wait_event(d2h_events[(i + 1) % 2]) + + B_previous_d = B_d[(i + 1) % 2] + B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize]) + L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1]) + L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1]) + L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1]) + + h2d_events[(i + 1) % 2].record(stream=h2d_stream) + + with compute_stream: + compute_stream.wait_event(h2d_events[i % 2]) + + B_d[i % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[i % 2], + B_d[i % 2] + - L_lower_diagonal_blocks_d[i % 2].conj().T + @ B_previous_d + - L_lower_arrow_blocks_d[i % 2].conj().T @ B_arrow_tip_d, + lower=True, + trans="C", + ) + + compute_B_events[i % 
2].record(compute_stream) + + B_d[i % 2].get(out=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=d2h_stream, blocking=False) + d2h_events[i % 2].record(stream=d2h_stream) else: raise ValueError(f"Invalid transpose argument: {trans}.") diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index cbb7aaed..4c51e79c 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -100,6 +100,7 @@ def test_pobtas( A_arrow_tip_block, B, trans="C", + device_streaming=True if array_type == "streaming" else False, ) assert xp.allclose(B, X_ref) From 14a9215f400a83a1d69642c1c0fd901bbf908ed3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 20:48:25 +0000 Subject: [PATCH 073/518] fixed typo --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 6da26720..9ce95372 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -501,7 +501,7 @@ def _pobtas_streaming( cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2] - - L_lower_arrow_blocks[(n_diag_blocks - 1) % 2].conj().T @ B_arrow_tip_d, + - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].conj().T @ B_arrow_tip_d, lower=True, trans="C", ) From 3ce3b0aef4cebe1d6c9368ee03bbf79e009a5855 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 20:50:58 +0000 Subject: [PATCH 074/518] another typo --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 9ce95372..fd811ef5 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -497,7 +497,7 @@ def _pobtas_streaming( trans="C", ) - B_d[(n_diag_blocks -1) % 2] = ( + B_d[(n_diag_blocks - 1) % 2] = ( cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 
2], B_d[(n_diag_blocks - 1) % 2] @@ -511,7 +511,7 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - B_d.get(out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=d2h_stream, blocking=False,) + B_d[(n_diag_blocks - 1) % 2].get(out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=d2h_stream, blocking=False,) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) for i in range(n_diag_blocks - 2, -1, -1): From 99c2e3d784bde472f1cf8648f7c59c90bac0e729 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 21:03:42 +0000 Subject: [PATCH 075/518] insert parenthesis --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index fd811ef5..94dfccbb 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -481,7 +481,7 @@ def _pobtas_streaming( h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) if n_diag_blocks > 1: B_d[n_diag_blocks % 2].set( - arr=B[-arrow_blocksize - 2 * diag_blocksize : -arrow_blocksize - diag_blocksize], + arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream ) From af5f83d8310f69c888c6af5016ca8ccc1ee6fec7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 21:13:32 +0000 Subject: [PATCH 076/518] insert debug staetments --- src/serinv/algs/pobtas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 94dfccbb..3b8e38e3 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -520,14 +520,15 @@ def _pobtas_streaming( if i > 0: with h2d_stream: h2d_stream.wait_event(d2h_events[(i + 1) % 2]) - + print(B_d[(i + 1) % 2]) B_previous_d = B_d[(i + 1) % 2] + print(B_previous_d) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * 
diag_blocksize]) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1]) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1]) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1]) - h2d_events[(i + 1) % 2].record(stream=h2d_stream) + h2d_events[(i - 1) % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) From 799af09d084964bbae435960dd727346b771b043 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 21:15:27 +0000 Subject: [PATCH 077/518] more debug --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 3b8e38e3..af83d80c 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -516,7 +516,7 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - + print(B) if i > 0: with h2d_stream: h2d_stream.wait_event(d2h_events[(i + 1) % 2]) From 5cc569eac5e1937181dfde3a3d5df8dc9a803873 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 21:27:42 +0000 Subject: [PATCH 078/518] added missing streaming --- src/serinv/algs/pobtas.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index af83d80c..78a2c541 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -484,6 +484,8 @@ def _pobtas_streaming( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream ) + L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) + L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) # ----- Backward substitution ----- if not partial: @@ -516,13 +518,10 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - 
L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - print(B) if i > 0: with h2d_stream: h2d_stream.wait_event(d2h_events[(i + 1) % 2]) - print(B_d[(i + 1) % 2]) B_previous_d = B_d[(i + 1) % 2] - print(B_previous_d) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize]) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1]) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1]) From fa95f1614daf34600dfa8ccf88945abbc1ec97eb Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 21:31:46 +0000 Subject: [PATCH 079/518] added debug statements --- src/serinv/algs/pobtas.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 78a2c541..e58854e6 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -484,8 +484,11 @@ def _pobtas_streaming( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream ) + print(B) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) + print(L_diagonal_blocks_d[n_diag_blocks % 2]) + print(L_lower_arrow_blocks_d[n_diag_blocks % 2]) # ----- Backward substitution ----- if not partial: From b394b91a4101c52185bed3b9af698541f13f45cb Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 21:32:38 +0000 Subject: [PATCH 080/518] changed debug --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index e58854e6..bd583559 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -484,7 +484,7 @@ def _pobtas_streaming( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream ) - print(B) + print(L_diagonal_blocks) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], 
stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) print(L_diagonal_blocks_d[n_diag_blocks % 2]) From 28220ad9571abe9214252f745310a8d2ca93c527 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 21:51:35 +0000 Subject: [PATCH 081/518] new debug statements --- src/serinv/algs/pobtas.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bd583559..e79cf865 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -484,11 +484,10 @@ def _pobtas_streaming( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream ) - print(L_diagonal_blocks) + print(B) + print(B_d[n_diag_blocks % 2]) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) - print(L_diagonal_blocks_d[n_diag_blocks % 2]) - print(L_lower_arrow_blocks_d[n_diag_blocks % 2]) # ----- Backward substitution ----- if not partial: From 7e19033b53ac506461a7b6974f3b4b5759769d1c Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:00:52 +0000 Subject: [PATCH 082/518] new debugs --- src/serinv/algs/pobtas.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index e79cf865..f4ae40a6 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -484,8 +484,6 @@ def _pobtas_streaming( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream ) - print(B) - print(B_d[n_diag_blocks % 2]) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) @@ -533,7 +531,7 @@ def _pobtas_streaming( with compute_stream: 
compute_stream.wait_event(h2d_events[i % 2]) - + print(B_previous_d) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 813638169f6cd7d42bdd31c7b0a349ca5a1ad8f9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:04:18 +0000 Subject: [PATCH 083/518] changed stream timing --- src/serinv/algs/pobtas.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f4ae40a6..52da1e2a 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -518,10 +518,13 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + with compute_stream: + B_previous_d = B_d[(i + 1) % 2] + if i > 0: with h2d_stream: h2d_stream.wait_event(d2h_events[(i + 1) % 2]) - B_previous_d = B_d[(i + 1) % 2] + B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize]) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1]) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1]) From cfa8307aa92e2d7dc5475a3cd461c4e9f91fe82f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:08:24 +0000 Subject: [PATCH 084/518] adjusted stram timing --- src/serinv/algs/pobtas.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 52da1e2a..65aa5e8f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -466,6 +466,7 @@ def _pobtas_streaming( # Events compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] + previous_B_event = cp.cuda.Event() h2d_events = [cp.cuda.Event(), cp.cuda.Event()] d2h_events = [cp.cuda.Event(), cp.cuda.Event()] @@ -519,18 +520,20 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} with compute_stream: + 
compute_stream.wait_event(d2h_events[(i + 1) % 2]) B_previous_d = B_d[(i + 1) % 2] + previous_B_event.record(stream=compute_stream) if i > 0: - with h2d_stream: - h2d_stream.wait_event(d2h_events[(i + 1) % 2]) + + h2d_stream.wait_event(previous_B_event) - B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize]) - L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1]) - L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1]) - L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1]) + B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) + L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) + L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) - h2d_events[(i - 1) % 2].record(stream=h2d_stream) + h2d_events[(i - 1) % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) From 304b3687c9d7d51498be40e517ac696b1c1f7997 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:26:50 +0000 Subject: [PATCH 085/518] changed event recording --- src/serinv/algs/pobtas.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 65aa5e8f..f1c5e39d 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -479,7 +479,7 @@ def _pobtas_streaming( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) - h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) if n_diag_blocks > 1: B_d[n_diag_blocks % 2].set( arr=B[-arrow_blocksize - (2 * diag_blocksize) : 
-arrow_blocksize - diag_blocksize], @@ -488,11 +488,15 @@ def _pobtas_streaming( L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) + h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) + + + # ----- Backward substitution ----- if not partial: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) with compute_stream: - compute_stream.wait_event(h2d_events[n_diag_blocks % 2]) + compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d = cu_la.solve_triangular( L_arrow_tip_block_d, B_arrow_tip_d, From dd82d4bfc537040f3a72c9a4a5dd4258c5d64f53 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:31:30 +0000 Subject: [PATCH 086/518] more debug --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f1c5e39d..3845b3c2 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -531,7 +531,7 @@ def _pobtas_streaming( if i > 0: h2d_stream.wait_event(previous_B_event) - + print("ping") B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) @@ -541,7 +541,7 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) - print(B_previous_d) + print("pong") B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 5370943e1cd4e2ae0e4c69c523da181b78ef90e7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:39:54 +0000 Subject: [PATCH 087/518] insert first compare debug --- src/serinv/algs/pobtas.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py 
b/src/serinv/algs/pobtas.py index 3845b3c2..0443cb1a 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -531,7 +531,7 @@ def _pobtas_streaming( if i > 0: h2d_stream.wait_event(previous_B_event) - print("ping") + B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) @@ -541,7 +541,9 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) - print("pong") + print(i) + print(L_diagonal_blocks) + print(L_diagonal_blocks_d[i % 2]) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 38289105b743b5dd72e51e687845aa03c2145405 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:42:40 +0000 Subject: [PATCH 088/518] second debug compare --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 0443cb1a..3a4aa4f0 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -542,8 +542,8 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) print(i) - print(L_diagonal_blocks) - print(L_diagonal_blocks_d[i % 2]) + print(L_lower_diagonal_blocks) + print(L_lower_diagonal_blocks_d[i % 2]) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From e36f83d118b8632e7f30034f9d60b59f04b34371 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:52:27 +0000 Subject: [PATCH 089/518] inserted lower diagonal blocks streaming --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 3a4aa4f0..8af88598 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -487,6 +487,7 @@ def 
_pobtas_streaming( ) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) From f71a31511881fe086d9303e318642acb239b1f4a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:54:55 +0000 Subject: [PATCH 090/518] debug compare 3 --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 8af88598..3a18661f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -543,8 +543,8 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) print(i) - print(L_lower_diagonal_blocks) - print(L_lower_diagonal_blocks_d[i % 2]) + print(L_lower_arrow_blocks) + print(L_lower_arrow_blocks_d[i % 2]) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From a2ddd30aed7d2bcf2a805f5cbecb8edaa54d5323 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 22:59:00 +0000 Subject: [PATCH 091/518] compare 4 --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 3a18661f..4055dc58 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -543,8 +543,8 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) print(i) - print(L_lower_arrow_blocks) - print(L_lower_arrow_blocks_d[i % 2]) + print(B) + print(B_previous_d) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 5d22f9476dba71913b661c886e5306a60b9f49c9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:06:56 +0000 Subject: [PATCH 092/518] changed location of B_previous 
--- src/serinv/algs/pobtas.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 4055dc58..3ab6031e 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -524,19 +524,16 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - with compute_stream: - compute_stream.wait_event(d2h_events[(i + 1) % 2]) - B_previous_d = B_d[(i + 1) % 2] - previous_B_event.record(stream=compute_stream) if i > 0: - h2d_stream.wait_event(previous_B_event) + h2d_stream.wait_event(d2h_events[(i + 1) % 2]) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) + B_previous_d.set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) h2d_events[(i - 1) % 2].record(stream=h2d_stream) From e59fd543233342c69a791b9fa59cc661259fb2f8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:10:50 +0000 Subject: [PATCH 093/518] added previous B setup --- src/serinv/algs/pobtas.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 3ab6031e..67f52ba0 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -480,16 +480,7 @@ def _pobtas_streaming( L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) - if n_diag_blocks > 1: - B_d[n_diag_blocks % 2].set( - arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], - 
stream=h2d_stream - ) - L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) - L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) - L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - - h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) + @@ -521,6 +512,21 @@ def _pobtas_streaming( B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) B_d[(n_diag_blocks - 1) % 2].get(out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=d2h_stream, blocking=False,) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) + previous_B_event.record(stream=d2h_stream) + + if n_diag_blocks > 1: + + B_d[n_diag_blocks % 2].set( + arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], + stream=h2d_stream + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) + L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) + h2d_stream.wait_event(previous_B_event) + B_previous_d.set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) + + h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} From 50e728dba3f77780dfacd656ce32f521fe673526 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:12:46 +0000 Subject: [PATCH 094/518] fixed indexing --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 67f52ba0..2ae221da 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -524,7 +524,7 @@ def _pobtas_streaming( 
L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_event) - B_previous_d.set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) + B_previous_d.set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) From 31233d087730b6939d55f9c8528bf68d2fc5f5a4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:17:31 +0000 Subject: [PATCH 095/518] moved brevious b from if --- src/serinv/algs/pobtas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 2ae221da..69b3415e 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -524,9 +524,10 @@ def _pobtas_streaming( L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_event) - B_previous_d.set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) + + B_previous_d.set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) - h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) + h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} From 2133c4d43380c73601b7fd93b4e48c0a4e3b9a94 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:19:20 +0000 Subject: [PATCH 096/518] moved previous b from correct if --- src/serinv/algs/pobtas.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 69b3415e..e21ac8e4 
100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -524,10 +524,9 @@ def _pobtas_streaming( L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_event) - - B_previous_d.set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) + B_previous_d.set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) - h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) + h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} @@ -540,9 +539,9 @@ def _pobtas_streaming( L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) - B_previous_d.set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) + B_previous_d.set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) - h2d_events[(i - 1) % 2].record(stream=h2d_stream) + h2d_events[(i - 1) % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) From 533af2a74560e93d586390a7c0055bb115c09e91 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:21:53 +0000 Subject: [PATCH 097/518] removed debug statements --- src/serinv/algs/pobtas.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index e21ac8e4..18420c83 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -539,15 +539,13 @@ def _pobtas_streaming( L_diagonal_blocks_d[(i - 1) % 
2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) + B_previous_d.set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) h2d_events[(i - 1) % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) - print(i) - print(B) - print(B_previous_d) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 2db72730a7c69616e05aaac022e93f500ef57c4d Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:22:50 +0000 Subject: [PATCH 098/518] moved a wait event --- src/serinv/algs/pobtas.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 18420c83..5c0a8809 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -531,10 +531,9 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - if i > 0: - - h2d_stream.wait_event(d2h_events[(i + 1) % 2]) + h2d_stream.wait_event(d2h_events[(i + 1) % 2]) + if i > 0: B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) From 45b717937ff88316ed3b14723d38be5b09f17f87 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:27:24 +0000 Subject: [PATCH 099/518] delayed d2h stream --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 5c0a8809..aa57a128 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -557,6 +557,7 @@ def 
_pobtas_streaming( compute_B_events[i % 2].record(compute_stream) + d2h_stream.wait_event(compute_B_events[i % 2]) B_d[i % 2].get(out=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From ab395e468420f89ad3bc8b36f8a00dfa796f7546 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:30:42 +0000 Subject: [PATCH 100/518] adjusted stream timing --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index aa57a128..5efe534e 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -539,6 +539,7 @@ def _pobtas_streaming( L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) + h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) B_previous_d.set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) h2d_events[(i - 1) % 2].record(stream=h2d_stream) From d4f0128710c5a8c53c9a5f7277f995591f7fed60 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 29 Apr 2025 23:34:22 +0000 Subject: [PATCH 101/518] even more adjusted timing --- src/serinv/algs/pobtas.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 5efe534e..5d810cfd 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -538,14 +538,16 @@ def _pobtas_streaming( L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) + h2d_events[(i - 1) % 2].record(stream=h2d_stream) h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) B_previous_d.set(arr=B[(i + 1) * diag_blocksize : 
(i + 2) * diag_blocksize], stream=h2d_stream) - - h2d_events[(i - 1) % 2].record(stream=h2d_stream) + previous_B_event.record(stream=d2h_stream) + with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) + h2d_stream.wait_event(previous_B_event) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From ba2d6acb1b4e92937d8d93aa87077a178d6b17d2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 12:09:15 +0000 Subject: [PATCH 102/518] changed streaming order --- src/serinv/algs/pobtas.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 5d810cfd..a40a476b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -530,24 +530,23 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - + with compute_stream: + compute_stream.wait_event(compute_B_events[(i - 1) % 2]) + compute_stream.wait_event(d2h_events[(i - 1) % 2]) + B_previous_d = B_d[(i - 1) % 2] + previous_B_event.record(stream=compute_stream) - h2d_stream.wait_event(d2h_events[(i + 1) % 2]) + if i > 0: + h2d_stream.wait_event(previous_B_event) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) h2d_events[(i - 1) % 2].record(stream=h2d_stream) - - h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d.set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) - previous_B_event.record(stream=d2h_stream) - with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) - h2d_stream.wait_event(previous_B_event) 
B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -560,9 +559,11 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) - d2h_stream.wait_event(compute_B_events[i % 2]) - B_d[i % 2].get(out=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=d2h_stream, blocking=False) + d2h_stream.wait_event(previous_B_event) + B_previous_d.get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) + + B_previous_d.get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: raise ValueError(f"Invalid transpose argument: {trans}.") From 5efb03a00566029e007c8844f2f419ce63924f01 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 12:13:52 +0000 Subject: [PATCH 103/518] removed strange get --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index a40a476b..338c0a16 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -563,7 +563,7 @@ def _pobtas_streaming( B_previous_d.get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) - B_previous_d.get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + # B_previous_d.get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: raise ValueError(f"Invalid transpose argument: {trans}.") From 66d2f6bfa4baef93c3b6fe698a351e324fb9760b Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 12:21:01 +0000 Subject: [PATCH 104/518] insert debug staetments --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 338c0a16..df4534e3 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -547,6 +547,8 @@ def _pobtas_streaming( with compute_stream: 
compute_stream.wait_event(h2d_events[i % 2]) + print(B_d) + print(B_previous_d) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From cd2b9c7e2ec1cb78389c9e349fb96e73bc7a1560 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 12:22:05 +0000 Subject: [PATCH 105/518] changed debug --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index df4534e3..640aeade 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -547,7 +547,7 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) - print(B_d) + print(B) print(B_previous_d) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], From 9db7858e7777088ccfa516093f3d136c5204ea5e Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 12:31:31 +0000 Subject: [PATCH 106/518] changed last get --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 640aeade..32fe28f4 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -565,7 +565,7 @@ def _pobtas_streaming( B_previous_d.get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) - # B_previous_d.get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + B_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: raise ValueError(f"Invalid transpose argument: {trans}.") From b37207be286536a66d9c8fe5c77ebc594ec820c3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 13:44:01 +0000 Subject: [PATCH 107/518] more debugging --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 32fe28f4..319d2106 100644 --- a/src/serinv/algs/pobtas.py +++ 
b/src/serinv/algs/pobtas.py @@ -533,7 +533,9 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(compute_B_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) + print(B_previous_d) B_previous_d = B_d[(i - 1) % 2] + print(B_previous_d) previous_B_event.record(stream=compute_stream) From ad6d37520c9f9c51c503f2aea6c12efa6ae19d59 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 15:28:53 +0000 Subject: [PATCH 108/518] changed B events --- src/serinv/algs/pobtas.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 319d2106..fef20ade 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -466,7 +466,7 @@ def _pobtas_streaming( # Events compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] - previous_B_event = cp.cuda.Event() + previous_B_events = [cp.cuda.Event(), cp.cuda.Event()] h2d_events = [cp.cuda.Event(), cp.cuda.Event()] d2h_events = [cp.cuda.Event(), cp.cuda.Event()] @@ -512,7 +512,7 @@ def _pobtas_streaming( B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) B_d[(n_diag_blocks - 1) % 2].get(out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=d2h_stream, blocking=False,) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) - previous_B_event.record(stream=d2h_stream) + previous_B_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) if n_diag_blocks > 1: @@ -523,7 +523,7 @@ def _pobtas_streaming( L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - h2d_stream.wait_event(previous_B_event) + h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) B_previous_d.set(arr=B[-arrow_blocksize - diag_blocksize : 
-arrow_blocksize], stream=h2d_stream) h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) @@ -536,11 +536,11 @@ def _pobtas_streaming( print(B_previous_d) B_previous_d = B_d[(i - 1) % 2] print(B_previous_d) - previous_B_event.record(stream=compute_stream) + previous_B_events[i % 2].record(stream=compute_stream) if i > 0: - h2d_stream.wait_event(previous_B_event) + h2d_stream.wait_event(previous_B_events[i % 2]) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) @@ -563,7 +563,7 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) - d2h_stream.wait_event(previous_B_event) + d2h_stream.wait_event(previous_B_events[i % 2]) B_previous_d.get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From e83d0b888b12b2520afc5f31eb00b64fba5f5fbb Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 15:34:54 +0000 Subject: [PATCH 109/518] print B_d --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index fef20ade..22f5b0a0 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -550,6 +550,7 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) print(B) + print(B_d) print(B_previous_d) B_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], From 6fd9ff1c594b66f09725a1bcccd413097f4e0254 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 15:36:00 +0000 Subject: [PATCH 110/518] insert seperator print --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 22f5b0a0..b7eb8af0 100644 --- 
a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -530,6 +530,7 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + print("---") with compute_stream: compute_stream.wait_event(compute_B_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) From 3bc6718e741b10ceddd4f0657ae869d42f4b9d61 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 15:43:58 +0000 Subject: [PATCH 111/518] changed location of previous B event --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index b7eb8af0..2913037b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -512,7 +512,7 @@ def _pobtas_streaming( B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) B_d[(n_diag_blocks - 1) % 2].get(out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=d2h_stream, blocking=False,) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) - previous_B_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) + if n_diag_blocks > 1: @@ -525,7 +525,7 @@ def _pobtas_streaming( L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) B_previous_d.set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) - + previous_B_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): From ab3fd2a53733de7f7eda9e05303f2e2522329cc2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 15:46:56 +0000 Subject: [PATCH 112/518] changed order of compute stream --- src/serinv/algs/pobtas.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git 
a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 2913037b..b310e36f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -531,13 +531,7 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} print("---") - with compute_stream: - compute_stream.wait_event(compute_B_events[(i - 1) % 2]) - compute_stream.wait_event(d2h_events[(i - 1) % 2]) - print(B_previous_d) - B_previous_d = B_d[(i - 1) % 2] - print(B_previous_d) - previous_B_events[i % 2].record(stream=compute_stream) + if i > 0: @@ -569,6 +563,14 @@ def _pobtas_streaming( B_previous_d.get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) + with compute_stream: + compute_stream.wait_event(compute_B_events[i % 2]) + compute_stream.wait_event(d2h_events[i % 2]) + print(B_previous_d) + B_previous_d = B_d[(i - 1) % 2] + print(B_previous_d) + previous_B_events[i % 2].record(stream=compute_stream) + B_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: From 5e72b089de522fcce97b92651ab94eaf24d8d0d4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 15:50:26 +0000 Subject: [PATCH 113/518] switched chose previous B --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index b310e36f..7b4b7a9b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -567,7 +567,7 @@ def _pobtas_streaming( compute_stream.wait_event(compute_B_events[i % 2]) compute_stream.wait_event(d2h_events[i % 2]) print(B_previous_d) - B_previous_d = B_d[(i - 1) % 2] + B_previous_d = B_d[i % 2] print(B_previous_d) previous_B_events[i % 2].record(stream=compute_stream) From 7155c8a29f36f22892d7fd6dea5b14ccdcd8aa6b Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 16:02:53 +0000 Subject: 
[PATCH 114/518] changed wait event --- src/serinv/algs/pobtas.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 7b4b7a9b..442e9858 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -535,7 +535,7 @@ def _pobtas_streaming( if i > 0: - h2d_stream.wait_event(previous_B_events[i % 2]) + h2d_stream.wait_event(previous_B_events[(i -1 ) % 2]) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) @@ -564,7 +564,6 @@ def _pobtas_streaming( d2h_events[i % 2].record(stream=d2h_stream) with compute_stream: - compute_stream.wait_event(compute_B_events[i % 2]) compute_stream.wait_event(d2h_events[i % 2]) print(B_previous_d) B_previous_d = B_d[i % 2] From 09f31c3b6e327d463e41f051cac4df3f06275ca4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 16:03:13 +0000 Subject: [PATCH 115/518] changed another wait event --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 442e9858..6dc4d695 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -559,7 +559,7 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) - d2h_stream.wait_event(previous_B_events[i % 2]) + d2h_stream.wait_event(previous_B_events[(i - 1) % 2]) B_previous_d.get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From d75b7ff7b49909b295a5bd74f29860e7d9a5418a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 16:57:40 +0000 Subject: [PATCH 116/518] changed stream pattern --- src/serinv/algs/pobtas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 
deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 6dc4d695..f3095387 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -535,7 +535,7 @@ def _pobtas_streaming( if i > 0: - h2d_stream.wait_event(previous_B_events[(i -1 ) % 2]) + h2d_stream.wait_event(previous_B_events[(i - 1 ) % 2]) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) @@ -559,7 +559,8 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) - d2h_stream.wait_event(previous_B_events[(i - 1) % 2]) + # d2h_stream.wait_event(previous_B_events[(i - 1) % 2]) + d2h_stream.wait_event(h2d_events[i % 2]) B_previous_d.get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From 87fb54b8971afb804b893626e92d8bd943cca085 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 18:14:08 +0000 Subject: [PATCH 117/518] changed previous B --- src/serinv/algs/pobtas.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f3095387..99086fe5 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -461,7 +461,9 @@ def _pobtas_streaming( elif trans == "T" or trans == "C": # Buffers - B_previous_d = cp.empty_like(B_shape) + B_previous_d = cp.empty( + (2, *B_shape.shape), dtype=B_shape.dtype + ) del B_shape # Events @@ -524,7 +526,7 @@ def _pobtas_streaming( L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) - 
B_previous_d.set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) + B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) previous_B_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) @@ -535,7 +537,7 @@ def _pobtas_streaming( if i > 0: - h2d_stream.wait_event(previous_B_events[(i - 1 ) % 2]) + h2d_stream.wait_event(compute_B_events[(i - 1 ) % 2]) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) @@ -544,14 +546,11 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[i % 2]) - print(B) - print(B_d) - print(B_previous_d) - B_d[i % 2] = cu_la.solve_triangular( + B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2].conj().T - @ B_previous_d + @ B_previous_d[(i - 1) % 2] - L_lower_arrow_blocks_d[i % 2].conj().T @ B_arrow_tip_d, lower=True, trans="C", @@ -559,19 +558,11 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) - # d2h_stream.wait_event(previous_B_events[(i - 1) % 2]) - d2h_stream.wait_event(h2d_events[i % 2]) - B_previous_d.get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + B_previous_d[(i - 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) - with compute_stream: - compute_stream.wait_event(d2h_events[i % 2]) - print(B_previous_d) - B_previous_d = B_d[i % 2] - print(B_previous_d) - previous_B_events[i % 2].record(stream=compute_stream) - - 
B_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: raise ValueError(f"Invalid transpose argument: {trans}.") From 464ca75d3e830f27ff8dd2935d3936b7f8e422cd Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 18:15:48 +0000 Subject: [PATCH 118/518] removed last B get --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 99086fe5..bee0774d 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -562,7 +562,7 @@ def _pobtas_streaming( B_previous_d[(i - 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) - B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + #B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: raise ValueError(f"Invalid transpose argument: {trans}.") From ae2e2699803b34ca32edbfc75a1ef974934e02f6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 19:02:19 +0000 Subject: [PATCH 119/518] changed indexing --- src/serinv/algs/pobtas.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bee0774d..c7058bd6 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -527,25 +527,21 @@ def _pobtas_streaming( L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) - previous_B_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) - h2d_events[n_diag_blocks % 2].record(stream=h2d_stream) + h2d_events[(n_diag_blocks - 1) % 
2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} print("---") - - - if i > 0: - h2d_stream.wait_event(compute_B_events[(i - 1 ) % 2]) + h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) - h2d_events[(i - 1) % 2].record(stream=h2d_stream) + h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: - compute_stream.wait_event(h2d_events[i % 2]) + compute_stream.wait_event(h2d_events[(i - 1) % 2]) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 729be575b9cf91ab441d098c9575ef524461bfbe Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 19:05:40 +0000 Subject: [PATCH 120/518] changed streaming a bit --- src/serinv/algs/pobtas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index c7058bd6..60925305 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -542,11 +542,12 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i - 1) % 2]) + compute_stream.wait_event(d2h_events[(i - 1) % 2]) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2].conj().T - @ B_previous_d[(i - 1) % 2] + @ B_previous_d[(i + 1) % 2] - L_lower_arrow_blocks_d[i % 2].conj().T @ B_arrow_tip_d, lower=True, trans="C", @@ -555,7 +556,7 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d[(i - 1) % 
2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) #B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) From f801076781c680256e8c4dd327ad3ffcba0e127a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 19:06:33 +0000 Subject: [PATCH 121/518] insert debug --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 60925305..6d47f232 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -543,6 +543,8 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) + print(B_d) + print(B_previous_d) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 44d858251b8e963a8d95adebe2e5870c576d2909 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 19:11:41 +0000 Subject: [PATCH 122/518] more debug --- src/serinv/algs/pobtas.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 6d47f232..0aed83c2 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -534,6 +534,7 @@ def _pobtas_streaming( print("---") if i > 0: h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) + print("h2d") B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) @@ -543,6 +544,7 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 
2]) + print("compute") print(B_d) print(B_previous_d) B_previous_d[i % 2] = cu_la.solve_triangular( @@ -558,6 +560,7 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + print("d2h") B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From 0c816b0438faed7c507cb0b49db7da0c5a01e464 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 19:39:28 +0000 Subject: [PATCH 123/518] inser print B --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 0aed83c2..8d77b73c 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -545,6 +545,7 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) print("compute") + print(B) print(B_d) print(B_previous_d) B_previous_d[i % 2] = cu_la.solve_triangular( From b0a6473952de23f82f6d3f09fa3457b439b06c07 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 19:40:23 +0000 Subject: [PATCH 124/518] another print B --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 8d77b73c..79b832d3 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -566,7 +566,7 @@ def _pobtas_streaming( d2h_events[i % 2].record(stream=d2h_stream) #B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) - + print(B) else: raise ValueError(f"Invalid transpose argument: {trans}.") From a73f8d20361c6bbb82cbfafda62da52413c2fbec Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 19:44:46 +0000 Subject: [PATCH 125/518] print xref --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 4c51e79c..0824e72b 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -102,5 +102,6 @@ def test_pobtas( trans="C", device_streaming=True if array_type == "streaming" else False, ) + print(X_ref) assert xp.allclose(B, X_ref) From 39138d09b1b22bcc18d60db13b0bc7c1252a5de3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 19:45:28 +0000 Subject: [PATCH 126/518] more debug --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 0824e72b..e9ce2384 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -102,6 +102,7 @@ def test_pobtas( trans="C", device_streaming=True if array_type == "streaming" else False, ) + print("===") print(X_ref) assert xp.allclose(B, X_ref) From 8b74d46ad525825d447ba4a01ebdcd2d3eadcbb5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 20:01:07 +0000 Subject: [PATCH 127/518] another B_d print --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 79b832d3..8509927f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -562,6 +562,7 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) print("d2h") + print(B_previous_d) B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From 43daa6870babe68030a8660e6535345d7a05fa6f Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 20:02:18 +0000 Subject: [PATCH 128/518] insert last B d2h --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 8509927f..a5146f09 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -566,7 +566,7 @@ def _pobtas_streaming( B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) - #B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) print(B) else: raise ValueError(f"Invalid transpose argument: {trans}.") From ac9f3d65434b29c564b383cc6c165f14105449f3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 20:03:12 +0000 Subject: [PATCH 129/518] condition last stream --- src/serinv/algs/pobtas.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index a5146f09..08fbcc3e 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -566,7 +566,8 @@ def _pobtas_streaming( B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) - B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + if n_diag_blocks > 1: + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) print(B) else: raise ValueError(f"Invalid transpose argument: {trans}.") From a74bcbe949e02198c4f82f7eea2f1364bdcbb348 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 20:15:05 +0000 Subject: [PATCH 130/518] insert wait event for last stream --- src/serinv/algs/pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 08fbcc3e..ecf02748 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -567,6 +567,7 @@ def _pobtas_streaming( d2h_events[i % 
2].record(stream=d2h_stream) if n_diag_blocks > 1: + d2h_stream.wait_event(compute_B_events[0]) B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) print(B) else: From 3e9644bb195c9757d5006318c9b05ae109f0538d Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 20:17:41 +0000 Subject: [PATCH 131/518] backward solve working --- src/serinv/algs/pobtas.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index ecf02748..b34fe185 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -531,10 +531,8 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - print("---") if i > 0: h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) - print("h2d") B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) @@ -544,10 +542,6 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - print("compute") - print(B) - print(B_d) - print(B_previous_d) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -561,15 +555,13 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - print("d2h") - print(B_previous_d) B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: d2h_stream.wait_event(compute_B_events[0]) B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) - print(B) + else: raise 
ValueError(f"Invalid transpose argument: {trans}.") From 7f17c0fed37896d4f00e4590f56b1a9305f28d1d Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 20:20:09 +0000 Subject: [PATCH 132/518] bigger tests --- tests/tests_algs/regular/conftest.py | 1 + tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_algs/regular/conftest.py b/tests/tests_algs/regular/conftest.py index 239baed2..1a1d730f 100644 --- a/tests/tests_algs/regular/conftest.py +++ b/tests/tests_algs/regular/conftest.py @@ -9,6 +9,7 @@ pytest.param(2, id="n_diag_blocks=2"), pytest.param(3, id="n_diag_blocks=3"), pytest.param(4, id="n_diag_blocks=4"), + pytest.param(4, id="n_diag_blocks=20"), ] diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index e9ce2384..7e3c0991 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -13,7 +13,7 @@ @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3, 10]) def test_pobtas( n_rhs: int, diagonal_blocksize: int, From 7f87fc73743065688366cac558d5240a2c1b5fe6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 30 Apr 2025 20:20:55 +0000 Subject: [PATCH 133/518] even bigger tests --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 7e3c0991..2ad9d895 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -13,7 +13,7 @@ @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3, 10]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3, 10, 40]) def test_pobtas( n_rhs: int, diagonal_blocksize: int, From 161613105b23c6382545cf842b2c16f92f6c7590 Mon Sep 17 00:00:00 
2001 From: 03szust Date: Wed, 30 Apr 2025 20:21:34 +0000 Subject: [PATCH 134/518] reverted tests for now --- tests/tests_algs/regular/conftest.py | 1 - tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tests_algs/regular/conftest.py b/tests/tests_algs/regular/conftest.py index 1a1d730f..239baed2 100644 --- a/tests/tests_algs/regular/conftest.py +++ b/tests/tests_algs/regular/conftest.py @@ -9,7 +9,6 @@ pytest.param(2, id="n_diag_blocks=2"), pytest.param(3, id="n_diag_blocks=3"), pytest.param(4, id="n_diag_blocks=4"), - pytest.param(4, id="n_diag_blocks=20"), ] diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 2ad9d895..e9ce2384 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -13,7 +13,7 @@ @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3, 10, 40]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3]) def test_pobtas( n_rhs: int, diagonal_blocksize: int, From 8335da7c384d1d2f9b9738c543b51bd6f0082aba Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 09:03:46 +0000 Subject: [PATCH 135/518] first attempt at adapted code for pobts --- src/serinv/algs/pobts.py | 124 ++++++++++++++++++ .../tests_algs/regular/tests_bt/test_pobts.py | 19 ++- 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 3fbcd0d4..a9456cd0 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -4,6 +4,7 @@ from serinv import ( ArrayLike, _get_module_from_array, + _get_module_from_str, ) @@ -150,3 +151,126 @@ def _pobts_permuted( ) else: raise ValueError(f"Invalid transpose argument: {trans}.") + + +def _pobts_streaming( + L_diagonal_blocks: ArrayLike, + L_lower_diagonal_blocks: ArrayLike, + B: ArrayLike, + trans: str, +): + arr_module, _ = 
_get_module_from_array(arr=L_diagonal_blocks) + if arr_module.__name__ != "numpy": + raise NotImplementedError( + "Host<->Device streaming only works when host-arrays are given." + ) + + cp, cu_la = _get_module_from_str(module_str="cupy") + + # Vars + diag_blocksize = L_diagonal_blocks.shape[1] + n_diag_blocks = L_diagonal_blocks.shape[0] + + # Streams + compute_stream = cp.cuda.Stream(non_blocking=True) + h2d_stream = cp.cuda.Stream(non_blocking=True) + d2h_stream = cp.cuda.Stream(non_blocking=True) + + # Device Buffers + # B Buffers + B_shape = B[0 : diag_blocksize] + B_d = cp.empty( + (2, *B_shape.shape), dtype=B_shape.dtype + ) + B_previous_d = cp.empty( + (2, *B_shape.shape), dtype=B_shape.dtype + ) + + # L Buffers + L_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + + # Events + compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] + previous_B_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_events = [cp.cuda.Event(), cp.cuda.Event()] + + if trans == "N": + raise NotImplementedError(f"Forward solve not implemented for streaming") + + elif trans == "T" or trans == "C": + print("hi") + + B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + with compute_stream: + B_d[(n_diag_blocks - 1) % 2] = ( + cu_la.solve_triangular( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + trans="C", + ) + ) + + compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + + d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) + B_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize], stream=d2h_stream, 
blocking=False,) + d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) + + if n_diag_blocks > 1: + + B_d[n_diag_blocks % 2].set( + arr=B[-(2 * diag_blocksize) : -diag_blocksize], + stream=h2d_stream + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) + h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) + B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize], stream=h2d_stream) + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + for i in range(n_diag_blocks - 2, -1, -1): + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + if i > 0: + h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) + B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) + L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) + h2d_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + compute_stream.wait_event(h2d_events[(i - 1) % 2]) + compute_stream.wait_event(d2h_events[(i - 1) % 2]) + B_previous_d[i % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[i % 2], + B_d[i % 2] + - L_lower_diagonal_blocks_d[i % 2].conj().T + @ B_previous_d[(i + 1) % 2], + lower=True, + trans="C", + ) + + compute_B_events[i % 2].record(compute_stream) + + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) + d2h_events[i % 2].record(stream=d2h_stream) + + if n_diag_blocks > 1: + d2h_stream.wait_event(compute_B_events[0]) + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + + else: + raise ValueError(f"Invalid transpose argument: {trans}.") + + 
cp.cuda.Device().synchronize() \ No newline at end of file diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index 8125df52..fdc145b0 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -3,11 +3,13 @@ import numpy as np import pytest -from serinv import _get_module_from_array +from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize, rhs from serinv.algs import pobtf, pobts +if backend_flags["cupy_avail"]: + import cupyx as cpx @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) @@ -18,6 +20,8 @@ def test_pobts( array_type: str, dtype: np.dtype, ): + array_type = "streaming" + A = dd_bt( diagonal_blocksize, n_diag_blocks, @@ -47,6 +51,18 @@ def test_pobts( _, ) = bt_dense_to_arrays(A, diagonal_blocksize, n_diag_blocks) + if backend_flags["cupy_avail"] and array_type == "streaming": + A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks) + A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :] + A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks) + A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks[:, :, :] + B_pinned = cpx.zeros_like_pinned(B) + B_pinned[:, :] = B[:, :] + + A_diagonal_blocks = A_diagonal_blocks_pinned + A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned + B = B_pinned + pobtf( A_diagonal_blocks, A_lower_diagonal_blocks, @@ -66,6 +82,7 @@ def test_pobts( A_lower_diagonal_blocks, B, trans="C", + device_streaming=True if array_type == "streaming" else False, ) assert xp.allclose(B, X_ref) From 46627ec2d6ca4f01d1483ef18333abed74a8b245 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 09:06:23 +0000 Subject: [PATCH 136/518] removed not implemented error --- src/serinv/algs/pobts.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git 
a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index a9456cd0..20028427 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -42,8 +42,11 @@ def pobts( else: # Natural arrowhead if device_streaming: - raise NotImplementedError( - "Streaming is not implemented for the natural arrowhead." + _pobts_streaming( + L_diagonal_blocks, + L_lower_diagonal_blocks, + B, + trans, ) else: _pobts( From 43afebe60776aac30e11f8df598874b9a81637ca Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 09:09:11 +0000 Subject: [PATCH 137/518] insert debug --- src/serinv/algs/pobts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 20028427..61788335 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -207,7 +207,8 @@ def _pobts_streaming( raise NotImplementedError(f"Forward solve not implemented for streaming") elif trans == "T" or trans == "C": - print("hi") + print(B_d) + print(B) B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) From 1b05487ecb0b11c050a006b0d337004496ab8144 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 09:10:57 +0000 Subject: [PATCH 138/518] fixed array slicing --- src/serinv/algs/pobts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 61788335..f1c80e93 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -210,7 +210,7 @@ def _pobts_streaming( print(B_d) print(B) - B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize], stream=h2d_stream) + B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) @@ -228,7 +228,7 @@ def 
_pobts_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) - B_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize], stream=d2h_stream, blocking=False,) + B_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False,) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) if n_diag_blocks > 1: @@ -240,7 +240,7 @@ def _pobts_streaming( L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) - B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize], stream=h2d_stream) + B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): From 34d6577edad2867ef5d6516cceec3be103e0a282 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 09:15:26 +0000 Subject: [PATCH 139/518] pobts streaming working --- src/serinv/algs/pobts.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index f1c80e93..4a1c80f7 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -207,9 +207,6 @@ def _pobts_streaming( raise NotImplementedError(f"Forward solve not implemented for streaming") elif trans == "T" or trans == "C": - print(B_d) - print(B) - B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) From 42e215e712158fabf3747bd94173110e469ece4e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 09:59:41 +0000 Subject: [PATCH 140/518] first attempt at pobts forward streaming by flipping it --- src/serinv/algs/pobts.py | 63 
+++++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 4a1c80f7..47d74bc8 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -204,7 +204,68 @@ def _pobts_streaming( d2h_events = [cp.cuda.Event(), cp.cuda.Event()] if trans == "N": - raise NotImplementedError(f"Forward solve not implemented for streaming") + B_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) + + h2d_events[0].record(stream=h2d_stream) + + with compute_stream: + B_d[0] = ( + cu_la.solve_triangular( + L_diagonal_blocks_d[0], + B_d[0], + lower=True, + trans="C", + ) + ) + + d2h_stream.wait_event(compute_B_events[0]) + B_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) + d2h_events[0].record(stream=d2h_stream) + + if n_diag_blocks > 1: + + B_d[1].set( + arr=B[diag_blocksize : (2 * diag_blocksize)], + stream=h2d_stream + ) + L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) + L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[1], stream=h2d_stream) + h2d_stream.wait_event(previous_B_events[0]) + B_previous_d[0].set(arr=B[-diag_blocksize:], stream=h2d_stream) + h2d_events[0].record(stream=h2d_stream) + + for i in range(1, n_diag_blocks - 1): + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + if i + 1 < n_diag_blocks - 1: + h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) + B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) + L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) + h2d_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + compute_stream.wait_event(h2d_events[(i + 1) % 2]) + compute_stream.wait_event(d2h_events[(i + 1) % 
2]) + B_previous_d[i % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[i % 2], + B_d[i % 2] + - L_lower_diagonal_blocks_d[i % 2].conj().T + @ B_previous_d[(i - 1) % 2], + lower=True, + trans="C", + ) + + compute_B_events[i % 2].record(compute_stream) + + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + B_previous_d[(i + 1) % 2].get(out=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=d2h_stream, blocking=False) + d2h_events[i % 2].record(stream=d2h_stream) + + if n_diag_blocks > 1: + d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) + B_previous_d[n_diag_blocks - 1].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + elif trans == "T" or trans == "C": B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) From abe2879d2a682980064a900bbef58f60cca3bf59 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:00:13 +0000 Subject: [PATCH 141/518] added test logic --- tests/tests_algs/regular/tests_bt/test_pobts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index fdc145b0..f5c941dc 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -74,6 +74,7 @@ def test_pobts( A_lower_diagonal_blocks, B, trans="N", + device_streaming=True if array_type == "streaming" else False, ) # Backward solve: X=L^{-T}Y From 96652d59a7ac72a7aef31c9ef8e2b8949c1ed076 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:01:44 +0000 Subject: [PATCH 142/518] changed indexing --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 47d74bc8..da666433 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -264,7 +264,7 @@ def _pobts_streaming( if n_diag_blocks > 1: d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) - 
B_previous_d[n_diag_blocks - 1].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) elif trans == "T" or trans == "C": From e6ce6c48f1a4df0b81edefc82e1912ef1b0d7240 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:06:28 +0000 Subject: [PATCH 143/518] fixed more indexing --- src/serinv/algs/pobts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index da666433..0552c974 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -230,12 +230,12 @@ def _pobts_streaming( stream=h2d_stream ) L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) - L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[1], stream=h2d_stream) + L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) B_previous_d[0].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) - for i in range(1, n_diag_blocks - 1): + for i in range(0, n_diag_blocks - 1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} if i + 1 < n_diag_blocks - 1: h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) From ea29b8ed5b21faefb6578972f6b65168100cdd30 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:10:16 +0000 Subject: [PATCH 144/518] switched event order --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 0552c974..cc646da1 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -230,7 +230,7 @@ def _pobts_streaming( stream=h2d_stream ) L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) - L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + 
L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) B_previous_d[0].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) From 6fab517f2a1e47c87cc57ab29a11e5fbf8c9d9fb Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:17:30 +0000 Subject: [PATCH 145/518] changed first block logic --- src/serinv/algs/pobts.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index cc646da1..9fa7ce33 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -215,10 +215,11 @@ def _pobts_streaming( L_diagonal_blocks_d[0], B_d[0], lower=True, - trans="C", ) ) + compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + d2h_stream.wait_event(compute_B_events[0]) B_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) d2h_events[0].record(stream=d2h_stream) @@ -230,7 +231,7 @@ def _pobts_streaming( stream=h2d_stream ) L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) - L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) B_previous_d[0].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) From 5e87ead868e4fc17ff18a7ca959b03355abacf2e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:19:36 +0000 Subject: [PATCH 146/518] fixed solve --- src/serinv/algs/pobts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 9fa7ce33..5196e26c 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -254,7 +254,6 @@ def _pobts_streaming( - L_lower_diagonal_blocks_d[i % 2].conj().T @ B_previous_d[(i - 1) % 2], lower=True, - trans="C", ) compute_B_events[i % 
2].record(compute_stream) From 82df44546e93285034e4c875fc9d55803c90868a Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:20:37 +0000 Subject: [PATCH 147/518] insert debug statement --- src/serinv/algs/pobts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 5196e26c..ea8876aa 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -259,6 +259,7 @@ def _pobts_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + print(B_previous_d) B_previous_d[(i + 1) % 2].get(out=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From 93976bd3567617b4c01abd3642e6e742d0e97d11 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:24:07 +0000 Subject: [PATCH 148/518] changed lower diagonal order --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index ea8876aa..02a90690 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -231,7 +231,7 @@ def _pobts_streaming( stream=h2d_stream ) L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) - L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) B_previous_d[0].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) From 72d1b8368e731e80a88e8f7105dba0059706a3bd Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:26:02 +0000 Subject: [PATCH 149/518] inser debug message --- src/serinv/algs/pobts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 02a90690..3687cb11 100644 --- a/src/serinv/algs/pobts.py +++ 
b/src/serinv/algs/pobts.py @@ -260,6 +260,7 @@ def _pobts_streaming( d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) print(B_previous_d) + print(B[(i - 1) * diag_blocksize : i * diag_blocksize]) B_previous_d[(i + 1) % 2].get(out=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From bf8077edcf4f0b56e5cc35982b981f8f9be8b279 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 10:28:30 +0000 Subject: [PATCH 150/518] changed slicing --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 3687cb11..53b011a9 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -261,7 +261,7 @@ def _pobts_streaming( d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) print(B_previous_d) print(B[(i - 1) * diag_blocksize : i * diag_blocksize]) - B_previous_d[(i + 1) % 2].get(out=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[(i + 1) % 2].get(out=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: From 046dffcafa662aff0d12f7a396b50e60d1ea41a5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:24:08 +0000 Subject: [PATCH 151/518] adjusted loop --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 53b011a9..143f4d49 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -236,7 +236,7 @@ def _pobts_streaming( B_previous_d[0].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) - for i in range(0, n_diag_blocks - 1): + for i in range(1, n_diag_blocks - 1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} if i + 1 < n_diag_blocks - 1: 
h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) From f3bc5856a1690795650eac6f63adf1c0b992c8f2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:25:44 +0000 Subject: [PATCH 152/518] adjusted loop --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 143f4d49..041ad9fb 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -252,7 +252,7 @@ def _pobts_streaming( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2].conj().T - @ B_previous_d[(i - 1) % 2], + @ B_previous_d[i % 2], lower=True, ) From dfbf23b00e07db1fc0c47f7f7e0d2965377223ed Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:35:27 +0000 Subject: [PATCH 153/518] changed previous B --- src/serinv/algs/pobts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 041ad9fb..972d81ff 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -233,11 +233,12 @@ def _pobts_streaming( L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) - B_previous_d[0].set(arr=B[-diag_blocksize:], stream=h2d_stream) + B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) for i in range(1, n_diag_blocks - 1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + if i + 1 < n_diag_blocks - 1: h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) From c59634a214fb2be022966040bc043bd89a431c2d Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:37:52 +0000 Subject: [PATCH 154/518] insert debug check 1 --- src/serinv/algs/pobts.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 972d81ff..6f267df0 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -249,6 +249,8 @@ def _pobts_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) + print(L_diagonal_blocks) + print(L_diagonal_blocks_d) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -260,8 +262,6 @@ def _pobts_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - print(B_previous_d) - print(B[(i - 1) * diag_blocksize : i * diag_blocksize]) B_previous_d[(i + 1) % 2].get(out=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) From 5f917064e1751cc41ca89f0862d66cd311134127 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:39:21 +0000 Subject: [PATCH 155/518] adjusted streaming --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 6f267df0..539ec20f 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -239,7 +239,7 @@ def _pobts_streaming( for i in range(1, n_diag_blocks - 1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - if i + 1 < n_diag_blocks - 1: + if i < n_diag_blocks - 1: h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) From adc84f87abb8a92820a91bd13f5d67e85164b083 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:40:15 +0000 Subject: [PATCH 156/518] adjusted streaming --- src/serinv/algs/pobts.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 
deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 539ec20f..94675fe6 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -239,12 +239,13 @@ def _pobts_streaming( for i in range(1, n_diag_blocks - 1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - if i < n_diag_blocks - 1: - h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) - B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) - L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) + + h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) + B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) + if i + 1 < n_diag_blocks - 1: L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) - h2d_events[i % 2].record(stream=h2d_stream) + h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[(i + 1) % 2]) From 1409c5d4cb3db62433df7ac266f7fa0f00b7a4d7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:41:44 +0000 Subject: [PATCH 157/518] insert more debug --- src/serinv/algs/pobts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 94675fe6..edf61cbf 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -252,6 +252,7 @@ def _pobts_streaming( compute_stream.wait_event(d2h_events[(i + 1) % 2]) print(L_diagonal_blocks) print(L_diagonal_blocks_d) + print(i % 2) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 6e13d6f01ef94e96bae47aa40c250f828ee24782 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:42:45 +0000 Subject: [PATCH 158/518] expanded for loop --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index edf61cbf..9f2101a2 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -236,7 +236,7 @@ def _pobts_streaming( B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) - for i in range(1, n_diag_blocks - 1): + for i in range(1, n_diag_blocks): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} From d856785f66341c501040e645d8fb9ec390a7efe0 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:43:23 +0000 Subject: [PATCH 159/518] adjusted streaming --- src/serinv/algs/pobts.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 9f2101a2..cd37581d 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -239,13 +239,13 @@ def _pobts_streaming( for i in range(1, n_diag_blocks): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - - h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) - B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) - L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) - if i + 1 < n_diag_blocks - 1: - L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) - h2d_events[i % 2].record(stream=h2d_stream) + if i < n_diag_blocks - 1: + h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) + B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) + if i + 1 < n_diag_blocks - 1: + L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) + h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[(i + 1) 
% 2]) From f46a64a1c6a18c49f101cad648cdfa0eebc9696b Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:44:34 +0000 Subject: [PATCH 160/518] check number 2 --- src/serinv/algs/pobts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index cd37581d..e407374b 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -250,8 +250,8 @@ def _pobts_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) - print(L_diagonal_blocks) - print(L_diagonal_blocks_d) + print(L_lower_diagonal_blocks) + print(L_lower_diagonal_blocks_d) print(i % 2) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], From 871a3b75f8236c9c60c89afdaedbadcca32113fe Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:49:05 +0000 Subject: [PATCH 161/518] shifted indexing --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index e407374b..b307daee 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -233,7 +233,7 @@ def _pobts_streaming( L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) - B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) + B_previous_d[1].set(arr=B[:diag_blocksize], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) for i in range(1, n_diag_blocks): From 6f4971cee1c7d8ec8c32dd983b54b5238af06c8e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 12:53:54 +0000 Subject: [PATCH 162/518] changed lower streaming --- src/serinv/algs/pobts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index b307daee..dbd74e0f 100644 --- 
a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -243,8 +243,7 @@ def _pobts_streaming( h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) - if i + 1 < n_diag_blocks - 1: - L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream) + L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i], stream=h2d_stream) h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: From 8307f51b509bdacfcaa373fb3eadc06cce36c680 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 13:06:30 +0000 Subject: [PATCH 163/518] more debug --- src/serinv/algs/pobts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index dbd74e0f..bd1fbfc9 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -239,7 +239,8 @@ def _pobts_streaming( for i in range(1, n_diag_blocks): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - if i < n_diag_blocks - 1: + if i + 1 < n_diag_blocks: + print(i) h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) From 742dcd390983d3712285b6916d2a9815490f6cce Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 13:09:29 +0000 Subject: [PATCH 164/518] removed some debug --- src/serinv/algs/pobts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index bd1fbfc9..3a61135f 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -240,7 +240,6 @@ def _pobts_streaming( # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} if 
i + 1 < n_diag_blocks: - print(i) h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) From c4f3fed9f010702b605791e470e0e01fc8ea9901 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 13:10:41 +0000 Subject: [PATCH 165/518] debug number 3 --- src/serinv/algs/pobts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 3a61135f..d6e60a9c 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -249,8 +249,8 @@ def _pobts_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) - print(L_lower_diagonal_blocks) - print(L_lower_diagonal_blocks_d) + print(B) + print(B_d) print(i % 2) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], From 620ee3b91b8c1e52d00a31b13023470736d58b27 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 13:52:21 +0000 Subject: [PATCH 166/518] changed B streaming --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index d6e60a9c..97584795 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -241,7 +241,7 @@ def _pobts_streaming( if i + 1 < n_diag_blocks: h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) - B_d[(i + 1) % 2].set(arr=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=h2d_stream) + B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i], stream=h2d_stream) h2d_events[i % 2].record(stream=h2d_stream) From 
82b8190a42e19e68a6992cacb49295b1e23c9504 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 13:54:06 +0000 Subject: [PATCH 167/518] more changes to B streaming --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 97584795..0dd1a44f 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -233,7 +233,7 @@ def _pobts_streaming( L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) - B_previous_d[1].set(arr=B[:diag_blocksize], stream=h2d_stream) + B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) for i in range(1, n_diag_blocks): From f88aad87e509d2c05922c080b5ce15467a86a2fd Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 13:55:10 +0000 Subject: [PATCH 168/518] changed B previous --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 0dd1a44f..71092252 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -256,7 +256,7 @@ def _pobts_streaming( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2].conj().T - @ B_previous_d[i % 2], + @ B_previous_d[(i + 1) % 2], lower=True, ) From 9e401ac6b8173f64a191ceb0547b9545fbb3e06b Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 13:58:04 +0000 Subject: [PATCH 169/518] removed wrong transposition --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 71092252..714706b1 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -255,7 +255,7 @@ def _pobts_streaming( B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 
2] - - L_lower_diagonal_blocks_d[i % 2].conj().T + - L_lower_diagonal_blocks_d[i % 2] @ B_previous_d[(i + 1) % 2], lower=True, ) From 7c24151821669e4ba3760e2081480a7dea5810f6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 14:44:07 +0000 Subject: [PATCH 170/518] debug check 4 --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 714706b1..edb1dbaf 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -250,7 +250,7 @@ def _pobts_streaming( compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) print(B) - print(B_d) + print(B_previous_d) print(i % 2) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], From eff1bdab4635c2409fa473280ed51231fa378932 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 14:56:17 +0000 Subject: [PATCH 171/518] debug b previous --- src/serinv/algs/pobts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index edb1dbaf..38512951 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -302,6 +302,7 @@ def _pobts_streaming( L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) + print(B_previous_d) B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) From c021ca075bed0747a83f8ed07c8e9ab8b6e18fb7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 14:57:12 +0000 Subject: [PATCH 172/518] moved debug message --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 
38512951..f5127092 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -233,6 +233,7 @@ def _pobts_streaming( L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) + print(B_previous_d) B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) @@ -302,7 +303,6 @@ def _pobts_streaming( L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) - print(B_previous_d) B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) From 77941229c3f4b30a19934985a58c07d152d9fe7f Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:00:05 +0000 Subject: [PATCH 173/518] shift B previous get --- src/serinv/algs/pobts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index f5127092..b2641b20 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -257,14 +257,14 @@ def _pobts_streaming( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2] - @ B_previous_d[(i + 1) % 2], + @ B_previous_d[(i - 1) % 2], lower=True, ) compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d[(i + 1) % 2].get(out=B[i * diag_blocksize : (i + 1) * diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[(i - 1) % 2].get(out=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: From 3e6d3c3246162eb62279ab4b9cd620aa82e63131 Mon Sep 17 
00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:03:04 +0000 Subject: [PATCH 174/518] changed last B --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index b2641b20..4f3ee4a5 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -268,7 +268,7 @@ def _pobts_streaming( d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: - d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) + d2h_stream.wait_event(compute_B_events[(n_diag_blocks) % 2]) B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) From e5fb88a7130d5e17ceb9731e1710dfce100e80c7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:07:26 +0000 Subject: [PATCH 175/518] test for last B --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 4f3ee4a5..c564df58 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -269,7 +269,7 @@ def _pobts_streaming( if n_diag_blocks > 1: d2h_stream.wait_event(compute_B_events[(n_diag_blocks) % 2]) - B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + # B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) elif trans == "T" or trans == "C": From 7eeb5c198ea2cc5d832ed83f22bd1d445077f3c3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:07:47 +0000 Subject: [PATCH 176/518] revert --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index c564df58..4f3ee4a5 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -269,7 +269,7 @@ def _pobts_streaming( if n_diag_blocks > 1: d2h_stream.wait_event(compute_B_events[(n_diag_blocks) % 2]) - # 
B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) elif trans == "T" or trans == "C": From 2afe74b81e798360f96ffb23ce9e9cebdbe9fb45 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:10:16 +0000 Subject: [PATCH 177/518] try different stream order --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 4f3ee4a5..7e5f5564 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -235,7 +235,7 @@ def _pobts_streaming( h2d_stream.wait_event(previous_B_events[0]) print(B_previous_d) B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) - h2d_events[0].record(stream=h2d_stream) + h2d_events[1].record(stream=h2d_stream) for i in range(1, n_diag_blocks): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} From c06b8496b434fb6620ff8e5d7e839d1217c1a666 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:13:11 +0000 Subject: [PATCH 178/518] insert failsafe --- src/serinv/algs/pobts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 7e5f5564..d935de1d 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -210,6 +210,7 @@ def _pobts_streaming( h2d_events[0].record(stream=h2d_stream) with compute_stream: + compute_stream.wait_event(h2d_events[0]) B_d[0] = ( cu_la.solve_triangular( L_diagonal_blocks_d[0], @@ -235,7 +236,7 @@ def _pobts_streaming( h2d_stream.wait_event(previous_B_events[0]) print(B_previous_d) B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) - h2d_events[1].record(stream=h2d_stream) + h2d_events[0].record(stream=h2d_stream) for i in range(1, n_diag_blocks): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T 
X_{ndb+1} From 01f67d16e744d76f43b5a1b5b2a72e4d01d2fe22 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:17:01 +0000 Subject: [PATCH 179/518] more failsafe --- src/serinv/algs/pobts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index d935de1d..f2885162 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -207,10 +207,10 @@ def _pobts_streaming( B_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) - h2d_events[0].record(stream=h2d_stream) + h2d_events[1].record(stream=h2d_stream) with compute_stream: - compute_stream.wait_event(h2d_events[0]) + compute_stream.wait_event(h2d_events[1]) B_d[0] = ( cu_la.solve_triangular( L_diagonal_blocks_d[0], @@ -234,7 +234,6 @@ def _pobts_streaming( L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_stream.wait_event(previous_B_events[0]) - print(B_previous_d) B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) From 5be4c6fbc941094cd3927d2c00cb6fbd46ad1b61 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:19:48 +0000 Subject: [PATCH 180/518] removed unnecessary events --- src/serinv/algs/pobts.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index f2885162..1d852b1a 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -233,7 +233,6 @@ def _pobts_streaming( ) L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) - h2d_stream.wait_event(previous_B_events[0]) B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) @@ -302,7 +301,6 @@ def _pobts_streaming( ) 
L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) From e65378058923dc8fb1799cd0adcd06c08c7558c4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:22:51 +0000 Subject: [PATCH 181/518] stream failsafes --- src/serinv/algs/pobts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 1d852b1a..5e110f9e 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -199,7 +199,6 @@ def _pobts_streaming( # Events compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] - previous_B_events = [cp.cuda.Event(), cp.cuda.Event()] h2d_events = [cp.cuda.Event(), cp.cuda.Event()] d2h_events = [cp.cuda.Event(), cp.cuda.Event()] @@ -226,7 +225,7 @@ def _pobts_streaming( d2h_events[0].record(stream=d2h_stream) if n_diag_blocks > 1: - + h2d_stream.wait_event(d2h_events[0]) B_d[1].set( arr=B[diag_blocksize : (2 * diag_blocksize)], stream=h2d_stream From 93b669a425e536ed745f3ab919dd8299a7d60b02 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:23:54 +0000 Subject: [PATCH 182/518] more failsafe --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 5e110f9e..18ed355e 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -293,7 +293,7 @@ def _pobts_streaming( d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) if n_diag_blocks > 1: - + h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) B_d[n_diag_blocks % 2].set( arr=B[-(2 * diag_blocksize) : -diag_blocksize], stream=h2d_stream From 
c6fc65fb8a28f39d3ebf95efc5efbd08879a6795 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:31:26 +0000 Subject: [PATCH 183/518] changed faulty event --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 18ed355e..3d108daf 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -218,7 +218,7 @@ def _pobts_streaming( ) ) - compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + compute_B_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_B_events[0]) B_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) From 12ca6405cc0c42c54f2c9d437e9b089eef349a01 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:34:25 +0000 Subject: [PATCH 184/518] changed last stream --- src/serinv/algs/pobts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 3d108daf..e7539abe 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -266,7 +266,7 @@ def _pobts_streaming( d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: - d2h_stream.wait_event(compute_B_events[(n_diag_blocks) % 2]) + d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) From 5efb2885b028881824279f26e80c41a281891167 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:35:57 +0000 Subject: [PATCH 185/518] removed unnecessary events --- src/serinv/algs/pobtas.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index b34fe185..0d09698c 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -468,7 +468,6 @@ def _pobtas_streaming( # Events compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] - previous_B_events = [cp.cuda.Event(), 
cp.cuda.Event()] h2d_events = [cp.cuda.Event(), cp.cuda.Event()] d2h_events = [cp.cuda.Event(), cp.cuda.Event()] @@ -525,7 +524,6 @@ def _pobtas_streaming( L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - h2d_stream.wait_event(previous_B_events[(n_diag_blocks - 1) % 2]) B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) From 18943347fa62c362855fca497a489baba9640034 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:39:03 +0000 Subject: [PATCH 186/518] more parity --- src/serinv/algs/pobtas.py | 4 ++-- src/serinv/algs/pobts.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 0d09698c..4f93df73 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -544,7 +544,7 @@ def _pobtas_streaming( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2].conj().T - @ B_previous_d[(i + 1) % 2] + @ B_previous_d[(i - 1) % 2] - L_lower_arrow_blocks_d[i % 2].conj().T @ B_arrow_tip_d, lower=True, trans="C", @@ -553,7 +553,7 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[(i - 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index e7539abe..2c04df7a 100644 --- 
a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -248,26 +248,23 @@ def _pobts_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) - print(B) - print(B_previous_d) - print(i % 2) B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2] - @ B_previous_d[(i - 1) % 2], + @ B_previous_d[(i + 1) % 2], lower=True, ) compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d[(i - 1) % 2].get(out=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[(i + 1) % 2].get(out=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: - d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) - B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2]) + B_previous_d[(n_diag_blocks + 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) elif trans == "T" or trans == "C": @@ -277,6 +274,7 @@ def _pobts_streaming( h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) with compute_stream: + compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) B_d[(n_diag_blocks - 1) % 2] = ( cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], @@ -319,7 +317,7 @@ def _pobts_streaming( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2].conj().T - @ B_previous_d[(i + 1) % 2], + @ B_previous_d[(i - 1) % 2], lower=True, trans="C", ) @@ -327,7 +325,7 @@ def _pobts_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d[(i + 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], 
stream=d2h_stream, blocking=False) + B_previous_d[(i - 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: From 9db222d33fb8ae2b0c6fb327ec5bab7ab8f3a76c Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:40:21 +0000 Subject: [PATCH 187/518] more failsafes --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 4f93df73..0492802b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -516,7 +516,7 @@ def _pobtas_streaming( if n_diag_blocks > 1: - + h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) B_d[n_diag_blocks % 2].set( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream From db2928dc25f26d1b986107d09014949e7abc38e2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:46:42 +0000 Subject: [PATCH 188/518] cosmetic changes --- src/serinv/algs/pobtas.py | 56 ++++++++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 0492802b..556f267f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -429,34 +429,46 @@ def _pobtas_streaming( # arrow tip block of the RHS. 
h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) + h2d_diagonal_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) + h2d_arrow_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) - B_d[(n_diag_blocks - 1) % 2] = (cu_la.solve_triangular(L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2], lower=True,)) + B_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True + ) compute_partial_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[0]) - B_d[(n_diag_blocks - 1) % 2].get(out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], stream=d2h_stream, blocking=False,) + + B_d[(n_diag_blocks - 1) % 2].get( + out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], + stream=d2h_stream, + blocking=False + ) + d2h_B_events[0].record(stream=d2h_stream) with compute_stream: compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d -= (L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2] @ B_d[(n_diag_blocks - 1) % 2]) + B_arrow_tip_d = cu_la.solve_triangular(L_arrow_tip_block_d, B_arrow_tip_d, lower=True) + compute_partial_events[1].record(stream=compute_stream) - compute_stream.wait_event(compute_partial_events[1]) - B_arrow_tip_d = cu_la.solve_triangular(L_arrow_tip_block_d, B_arrow_tip_d, lower=True) - compute_partial_events[0].record(stream=compute_stream) + d2h_stream.wait_event(compute_partial_events[1]) - d2h_stream.wait_event(compute_partial_events[0]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) elif trans == 
"T" or trans == "C": @@ -476,7 +488,10 @@ def _pobtas_streaming( B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) - B_d[(n_diag_blocks - 1) % 2].set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) + B_d[(n_diag_blocks - 1) % 2].set( + arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + stream=h2d_stream + ) L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) @@ -510,36 +525,50 @@ def _pobtas_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - B_d[(n_diag_blocks - 1) % 2].get(out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=d2h_stream, blocking=False,) + B_d[(n_diag_blocks - 1) % 2].get( + out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + stream=d2h_stream, + blocking=False + + ) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) if n_diag_blocks > 1: h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) + B_d[n_diag_blocks % 2].set( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream ) + B_previous_d[(n_diag_blocks - 1) % 2].set( + arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + stream=h2d_stream + ) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=h2d_stream) + 
h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} if i > 0: h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) + B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) + h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) + B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -553,11 +582,18 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d[(i - 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) + + B_previous_d[(i - 1) % 2].get( + out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=d2h_stream, + blocking=False + + ) d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: d2h_stream.wait_event(compute_B_events[0]) + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: From 47c9f5cffe2e04c19fe524024020de2d6f2c8955 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:54:49 +0000 Subject: [PATCH 189/518] more cosmetic changes --- src/serinv/algs/pobtas.py | 2 ++ src/serinv/algs/pobts.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 556f267f..f4d692c6 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -420,7 +420,9 @@ def 
_pobtas_streaming( compute_arrow_B_events[i % 2].record(stream=compute_stream) d2h_stream.wait_event(compute_arrow_B_events[i % 2]) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + d2h_tip_events[i % 2].record(stream=d2h_stream) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 2c04df7a..351838d6 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -188,6 +188,7 @@ def _pobts_streaming( B_previous_d = cp.empty( (2, *B_shape.shape), dtype=B_shape.dtype ) + del B_shape # L Buffers L_diagonal_blocks_d = cp.empty( @@ -210,6 +211,7 @@ def _pobts_streaming( with compute_stream: compute_stream.wait_event(h2d_events[1]) + B_d[0] = ( cu_la.solve_triangular( L_diagonal_blocks_d[0], @@ -221,11 +223,14 @@ def _pobts_streaming( compute_B_events[0].record(stream=compute_stream) d2h_stream.wait_event(compute_B_events[0]) + B_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) + d2h_events[0].record(stream=d2h_stream) if n_diag_blocks > 1: h2d_stream.wait_event(d2h_events[0]) + B_d[1].set( arr=B[diag_blocksize : (2 * diag_blocksize)], stream=h2d_stream @@ -233,6 +238,7 @@ def _pobts_streaming( L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) + h2d_events[0].record(stream=h2d_stream) for i in range(1, n_diag_blocks): @@ -240,14 +246,17 @@ def _pobts_streaming( if i + 1 < n_diag_blocks: h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) + B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i], stream=h2d_stream) + h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[(i + 1) 
% 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) + B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -259,11 +268,18 @@ def _pobts_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d[(i + 1) % 2].get(out=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=d2h_stream, blocking=False) + + B_previous_d[(i + 1) % 2].get( + out=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=d2h_stream, + blocking=False + ) + d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2]) + B_previous_d[(n_diag_blocks + 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) @@ -275,6 +291,7 @@ def _pobts_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) + B_d[(n_diag_blocks - 1) % 2] = ( cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], @@ -287,11 +304,14 @@ def _pobts_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) - B_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False,) + + B_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) if n_diag_blocks > 1: h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) + B_d[n_diag_blocks % 2].set( arr=B[-(2 * diag_blocksize) : -diag_blocksize], stream=h2d_stream @@ -299,20 +319,24 @@ def _pobts_streaming( L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) + h2d_events[(n_diag_blocks - 1) % 
2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} if i > 0: h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) + B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) + h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) + B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -325,11 +349,14 @@ def _pobts_streaming( compute_B_events[i % 2].record(compute_stream) d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + B_previous_d[(i - 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) + d2h_events[i % 2].record(stream=d2h_stream) if n_diag_blocks > 1: d2h_stream.wait_event(compute_B_events[0]) + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: From 3ba9b9fc6bc720c629bb3b86509533b26bd3f921 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:55:58 +0000 Subject: [PATCH 190/518] attempt to reduce streaming --- src/serinv/algs/pobts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 351838d6..67fb2c95 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -292,7 +292,7 @@ def _pobts_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) - B_d[(n_diag_blocks - 1) % 2] = ( + B_previous_d[(n_diag_blocks - 1) % 2] = ( cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2], @@ -305,7 +305,7 @@ def _pobts_streaming( 
d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) - B_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) @@ -318,7 +318,7 @@ def _pobts_streaming( ) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) + #B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) From 3d9e334be789fb12f00989ed4cd8bc89b9a13de2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 15:59:26 +0000 Subject: [PATCH 191/518] reduced streaming --- src/serinv/algs/pobtas.py | 8 ++------ src/serinv/algs/pobts.py | 6 ++---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f4d692c6..0744e7df 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -514,7 +514,7 @@ def _pobtas_streaming( trans="C", ) - B_d[(n_diag_blocks - 1) % 2] = ( + B_previous_d[(n_diag_blocks - 1) % 2] = ( cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2] @@ -529,7 +529,7 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - B_d[(n_diag_blocks - 1) % 2].get( + B_previous_d[(n_diag_blocks - 1) % 2].get( out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], stream=d2h_stream, blocking=False @@ -545,10 +545,6 @@ def _pobtas_streaming( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], stream=h2d_stream 
) - B_previous_d[(n_diag_blocks - 1) % 2].set( - arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], - stream=h2d_stream - ) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 67fb2c95..30aac43c 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -212,7 +212,7 @@ def _pobts_streaming( with compute_stream: compute_stream.wait_event(h2d_events[1]) - B_d[0] = ( + B_previous_d[0] = ( cu_la.solve_triangular( L_diagonal_blocks_d[0], B_d[0], @@ -224,7 +224,7 @@ def _pobts_streaming( d2h_stream.wait_event(compute_B_events[0]) - B_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) d2h_events[0].record(stream=d2h_stream) @@ -237,7 +237,6 @@ def _pobts_streaming( ) L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) - B_previous_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) h2d_events[0].record(stream=h2d_stream) @@ -318,7 +317,6 @@ def _pobts_streaming( ) L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - #B_previous_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) From 8bf1908c9a95986bd34d2b35c878da629fcef7f4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:01:14 +0000 Subject: [PATCH 192/518] attempt to reduce streaming --- src/serinv/algs/pobts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 30aac43c..5627d6dd 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -302,9 +302,9 @@ def _pobts_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) - d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) + #d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) - B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + #B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) @@ -352,7 +352,7 @@ def _pobts_streaming( d2h_events[i % 2].record(stream=d2h_stream) - if n_diag_blocks > 1: + if n_diag_blocks > 0: d2h_stream.wait_event(compute_B_events[0]) B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) From 25c9e56d885bd4f5d2e963194a43da651b204fdd Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:02:40 +0000 Subject: [PATCH 193/518] parity reduced streaming --- src/serinv/algs/pobts.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 5627d6dd..dd3459b2 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -222,9 +222,9 @@ def _pobts_streaming( compute_B_events[0].record(stream=compute_stream) - d2h_stream.wait_event(compute_B_events[0]) + #d2h_stream.wait_event(compute_B_events[0]) - B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) + #B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) d2h_events[0].record(stream=d2h_stream) @@ -276,10 +276,9 @@ def _pobts_streaming( d2h_events[i % 2].record(stream=d2h_stream) - if n_diag_blocks > 1: - d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2]) + 
d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2]) - B_previous_d[(n_diag_blocks + 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) + B_previous_d[(n_diag_blocks + 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) elif trans == "T" or trans == "C": @@ -302,10 +301,6 @@ def _pobts_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) - #d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) - - #B_previous_d[(n_diag_blocks - 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) - d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) if n_diag_blocks > 1: @@ -352,10 +347,10 @@ def _pobts_streaming( d2h_events[i % 2].record(stream=d2h_stream) - if n_diag_blocks > 0: - d2h_stream.wait_event(compute_B_events[0]) + + d2h_stream.wait_event(compute_B_events[0]) - B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: raise ValueError(f"Invalid transpose argument: {trans}.") From 83a681cf35c16699e9a1e6067657b81392e6413c Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:04:46 +0000 Subject: [PATCH 194/518] attempt to fuirther reduce streaming --- src/serinv/algs/pobtas.py | 17 ++++++++--------- src/serinv/algs/pobts.py | 4 ---- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 0744e7df..16a54ee7 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -529,12 +529,11 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - B_previous_d[(n_diag_blocks - 1) % 2].get( - out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], - stream=d2h_stream, - blocking=False - - ) + #B_previous_d[(n_diag_blocks - 1) % 2].get( + # 
out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + # stream=d2h_stream, + # blocking=False + #) d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) @@ -589,10 +588,10 @@ def _pobtas_streaming( ) d2h_events[i % 2].record(stream=d2h_stream) - if n_diag_blocks > 1: - d2h_stream.wait_event(compute_B_events[0]) + #if n_diag_blocks > 1: + d2h_stream.wait_event(compute_B_events[0]) - B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: raise ValueError(f"Invalid transpose argument: {trans}.") diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index dd3459b2..2f216173 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -221,10 +221,6 @@ def _pobts_streaming( ) compute_B_events[0].record(stream=compute_stream) - - #d2h_stream.wait_event(compute_B_events[0]) - - #B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False,) d2h_events[0].record(stream=d2h_stream) From a685906a1e3c1a451b9f27d8dc2bd235bfe826c2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:09:31 +0000 Subject: [PATCH 195/518] speed up setup attempt --- src/serinv/algs/pobtas.py | 7 +------ src/serinv/algs/pobts.py | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 16a54ee7..2e69178d 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -529,11 +529,7 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - #B_previous_d[(n_diag_blocks - 1) % 2].get( - # out=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], - # stream=d2h_stream, - # blocking=False - #) + d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) @@ -588,7 +584,6 @@ def _pobtas_streaming( ) d2h_events[i % 
2].record(stream=d2h_stream) - #if n_diag_blocks > 1: d2h_stream.wait_event(compute_B_events[0]) B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 2f216173..3278f94f 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -297,10 +297,10 @@ def _pobts_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) - d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) + #d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) if n_diag_blocks > 1: - h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) + #h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) B_d[n_diag_blocks % 2].set( arr=B[-(2 * diag_blocksize) : -diag_blocksize], From 37118fd9a319858caf0e6a0d6ea694f5a59366fa Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:11:06 +0000 Subject: [PATCH 196/518] expand delay reduction --- src/serinv/algs/pobtas.py | 4 ++-- src/serinv/algs/pobts.py | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 2e69178d..6798e6f4 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -530,11 +530,11 @@ def _pobtas_streaming( B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) + #d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) if n_diag_blocks > 1: - h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) + #h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) B_d[n_diag_blocks % 2].set( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 3278f94f..9b07c435 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -221,11 +221,8 @@ def _pobts_streaming( ) 
compute_B_events[0].record(stream=compute_stream) - - d2h_events[0].record(stream=d2h_stream) if n_diag_blocks > 1: - h2d_stream.wait_event(d2h_events[0]) B_d[1].set( arr=B[diag_blocksize : (2 * diag_blocksize)], @@ -297,10 +294,7 @@ def _pobts_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) - #d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) - if n_diag_blocks > 1: - #h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) B_d[n_diag_blocks % 2].set( arr=B[-(2 * diag_blocksize) : -diag_blocksize], From 73996a28263b1a32199e185e14bedae1652d57e5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:35:09 +0000 Subject: [PATCH 197/518] comment changes --- src/serinv/algs/pobtas.py | 18 ++++++------------ src/serinv/algs/pobts.py | 7 +++++-- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 6798e6f4..a457528c 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -296,7 +296,8 @@ def _pobtas_streaming( compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] - # Forward Pass + # --- Forward substitution --- + # --- C: events + transfers --- compute_current_B_events[1].record(stream=compute_stream) compute_next_B_events[1].record(stream=compute_stream) @@ -325,8 +326,7 @@ def _pobtas_streaming( L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_lower_diagonal_events[0].record(stream=h2d_stream) - - # --- Forward substitution --- + # --- Computations --- for i in range(0, n_diag_blocks - 1): if i + 1 < n_diag_blocks: @@ -487,7 +487,6 @@ def _pobtas_streaming( # Forward Pass # --- C: events + transfers --- - B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) B_d[(n_diag_blocks - 1) % 2].set( @@ -499,9 +498,6 @@ def _pobtas_streaming( h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) - - - # 
----- Backward substitution ----- if not partial: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) @@ -530,11 +526,8 @@ def _pobtas_streaming( B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - #d2h_events[(n_diag_blocks - 1) % 2].record(stream=d2h_stream) - if n_diag_blocks > 1: - #h2d_stream.wait_event(d2h_events[(n_diag_blocks - 1) % 2]) B_d[n_diag_blocks % 2].set( arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], @@ -547,7 +540,7 @@ def _pobtas_streaming( h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): - # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + if i > 0: h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) @@ -561,7 +554,8 @@ def _pobtas_streaming( with compute_stream: compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - + + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 9b07c435..0420841c 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -222,6 +222,7 @@ def _pobts_streaming( compute_B_events[0].record(stream=compute_stream) + if n_diag_blocks > 1: B_d[1].set( @@ -234,7 +235,7 @@ def _pobts_streaming( h2d_events[0].record(stream=h2d_stream) for i in range(1, n_diag_blocks): - # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + if i + 1 < n_diag_blocks: h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) @@ -249,6 +250,7 @@ def _pobts_streaming( compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ 
-306,7 +308,7 @@ def _pobts_streaming( h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): - # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + if i > 0: h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) @@ -320,6 +322,7 @@ def _pobts_streaming( compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] From 9d3dda0ebddb6df88cc9b71609945b83c33bd5e9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:42:26 +0000 Subject: [PATCH 198/518] check for useless if --- src/serinv/algs/pobtas.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index a457528c..355aeae2 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -277,6 +277,7 @@ def _pobtas_streaming( L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) if trans == "N": + # --- Forward substitution --- # delete helper variable del B_shape @@ -296,8 +297,6 @@ def _pobtas_streaming( compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] - # --- Forward substitution --- - # --- C: events + transfers --- compute_current_B_events[1].record(stream=compute_stream) compute_next_B_events[1].record(stream=compute_stream) @@ -329,16 +328,15 @@ def _pobtas_streaming( # --- Computations --- for i in range(0, n_diag_blocks - 1): - if i + 1 < n_diag_blocks: - # stream next B block - h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) - - B_d[(i + 1) % 2].set( - arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], - stream = h2d_stream - ) + #if i + 1 < n_diag_blocks: + # stream next B block + h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + B_d[(i + 1) % 2].set( + arr=B[(i + 1) * 
diag_blocksize : (i + 2) * diag_blocksize], + stream = h2d_stream + ) - h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) if i + 1 < n_diag_blocks - 1: # stream next diagonal block From c2427fe4df9f7c9c5f0aaff2c52275a5fa20f571 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:47:01 +0000 Subject: [PATCH 199/518] check for duplicate --- src/serinv/algs/pobtas.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 355aeae2..2056fc39 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -327,10 +327,9 @@ def _pobtas_streaming( # --- Computations --- for i in range(0, n_diag_blocks - 1): - - #if i + 1 < n_diag_blocks: - # stream next B block + # pass next B block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + B_d[(i + 1) % 2].set( arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream = h2d_stream @@ -339,7 +338,7 @@ def _pobtas_streaming( h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) if i + 1 < n_diag_blocks - 1: - # stream next diagonal block + # pass next diagonal block h2d_stream.wait_event(compute_current_B_events[(i + 1) % 2]) L_diagonal_blocks_d[(i + 1) % 2].set( @@ -351,7 +350,7 @@ def _pobtas_streaming( with compute_stream: - # Compute step 1 : compute B + # Solve current B compute_stream.wait_event(h2d_diagonal_events[i % 2]) B_d[i % 2] = cu_la.solve_triangular( @@ -362,7 +361,7 @@ def _pobtas_streaming( compute_current_B_events[i % 2].record(stream=compute_stream) - # stream B back + # Pass current B back d2h_stream.wait_event(compute_current_B_events[i % 2]) B_d[i % 2].get( @@ -374,7 +373,7 @@ def _pobtas_streaming( d2h_B_events[i % 2].record(stream=d2h_stream) if i + 1 < n_diag_blocks - 1: - # stream next lower diagonal block + # Pass next lower diagonal block h2d_stream.wait_event(compute_next_B_events[(i + 1) % 2]) 
L_lower_diagonal_blocks_d[(i + 1) % 2].set( @@ -385,7 +384,7 @@ def _pobtas_streaming( h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: - # Compute step 2 : update next B + # Update next B compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) B_d[(i + 1) % 2] -= ( @@ -396,7 +395,7 @@ def _pobtas_streaming( compute_next_B_events[i % 2].record(stream=compute_stream) if i + 1 < n_diag_blocks - 1: - # stream next lower arrow block + # Pass next lower arrow block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) L_lower_arrow_blocks_d[(i + 1) % 2].set( @@ -407,7 +406,7 @@ def _pobtas_streaming( h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: - # Compute step 3 : update arrow tip + # Update arrow tip compute_stream.wait_event(h2d_arrow_events[i % 2]) B_arrow_tip_d -= ( @@ -417,11 +416,12 @@ def _pobtas_streaming( compute_arrow_B_events[i % 2].record(stream=compute_stream) - d2h_stream.wait_event(compute_arrow_B_events[i % 2]) + # Pass arrow tip back + d2h_stream.wait_event(compute_arrow_B_events[i % 2]) - B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - d2h_tip_events[i % 2].record(stream=d2h_stream) + d2h_tip_events[i % 2].record(stream=d2h_stream) if not partial: From 2f159bfb31a33a92f65cbb6852cef4ca7c6566d4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 16:48:11 +0000 Subject: [PATCH 200/518] reverted --- src/serinv/algs/pobtas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 2056fc39..d85e140b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -417,11 +417,11 @@ def _pobtas_streaming( compute_arrow_B_events[i % 2].record(stream=compute_stream) # Pass arrow tip back - d2h_stream.wait_event(compute_arrow_B_events[i % 2]) + 
d2h_stream.wait_event(compute_arrow_B_events[i % 2]) - B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - d2h_tip_events[i % 2].record(stream=d2h_stream) + d2h_tip_events[i % 2].record(stream=d2h_stream) if not partial: From d710e38cc73e20af98c297ebe754a0d16682e827 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 17:29:33 +0000 Subject: [PATCH 201/518] reduced for loop --- src/serinv/algs/pobtas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index d85e140b..53311f06 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -417,11 +417,11 @@ def _pobtas_streaming( compute_arrow_B_events[i % 2].record(stream=compute_stream) # Pass arrow tip back - d2h_stream.wait_event(compute_arrow_B_events[i % 2]) + d2h_stream.wait_event(compute_arrow_B_events[n_diag_blocks % 2]) - B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - d2h_tip_events[i % 2].record(stream=d2h_stream) + d2h_tip_events[n_diag_blocks % 2].record(stream=d2h_stream) if not partial: From e3dc9d36ef93aa9c946fff57b9012c91668ac614 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 18:38:07 +0000 Subject: [PATCH 202/518] reordered streaming --- src/serinv/algs/pobtas.py | 39 +++++++++++++++++++++++++-------------- src/serinv/algs/pobts.py | 27 +++++++++++++++++---------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 53311f06..f5f52976 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -277,9 +277,9 @@ def _pobtas_streaming( L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) if trans == "N": - # --- Forward substitution --- + # ----- Forward substitution ----- - # delete 
helper variable + # Delete helper variable del B_shape # Events @@ -350,7 +350,7 @@ def _pobtas_streaming( with compute_stream: - # Solve current B + # Solve current B block compute_stream.wait_event(h2d_diagonal_events[i % 2]) B_d[i % 2] = cu_la.solve_triangular( @@ -361,7 +361,7 @@ def _pobtas_streaming( compute_current_B_events[i % 2].record(stream=compute_stream) - # Pass current B back + # Pass current B block back d2h_stream.wait_event(compute_current_B_events[i % 2]) B_d[i % 2].get( @@ -384,7 +384,7 @@ def _pobtas_streaming( h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: - # Update next B + # Update next B block compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) B_d[(i + 1) % 2] -= ( @@ -416,7 +416,7 @@ def _pobtas_streaming( compute_arrow_B_events[i % 2].record(stream=compute_stream) - # Pass arrow tip back + # Pass arrow tip back d2h_stream.wait_event(compute_arrow_B_events[n_diag_blocks % 2]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) @@ -425,9 +425,7 @@ def _pobtas_streaming( if not partial: - # In the case of the partial solve, we do not solve the last block and - # arrow tip block of the RHS. 
- + # Pass last blocks h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) @@ -440,15 +438,18 @@ def _pobtas_streaming( with compute_stream: - + # Solve last B block compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) + B_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2], lower=True ) + compute_partial_events[0].record(stream=compute_stream) + # Pass last B block back d2h_stream.wait_event(compute_partial_events[0]) B_d[(n_diag_blocks - 1) % 2].get( @@ -460,6 +461,7 @@ def _pobtas_streaming( d2h_B_events[0].record(stream=d2h_stream) with compute_stream: + # Solve arrow tip compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d -= (L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2] @ B_d[(n_diag_blocks - 1) % 2]) @@ -472,10 +474,14 @@ def _pobtas_streaming( B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) elif trans == "T" or trans == "C": + # ----- Backward substitution ----- + # Buffers B_previous_d = cp.empty( (2, *B_shape.shape), dtype=B_shape.dtype ) + + # Delete helper variable del B_shape # Events @@ -483,8 +489,7 @@ def _pobtas_streaming( h2d_events = [cp.cuda.Event(), cp.cuda.Event()] d2h_events = [cp.cuda.Event(), cp.cuda.Event()] - # Forward Pass - # --- C: events + transfers --- + # --- H2D: transfers --- B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) B_d[(n_diag_blocks - 1) % 2].set( @@ -498,8 +503,9 @@ def _pobtas_streaming( # ----- Backward substitution ----- if not partial: - # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) + with compute_stream: + # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d = cu_la.solve_triangular( 
L_arrow_tip_block_d, @@ -508,6 +514,7 @@ def _pobtas_streaming( trans="C", ) + # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) B_previous_d[(n_diag_blocks - 1) % 2] = ( cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], @@ -520,6 +527,7 @@ def _pobtas_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + # Pass arrow tip back d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) @@ -540,6 +548,7 @@ def _pobtas_streaming( for i in range(n_diag_blocks - 2, -1, -1): if i > 0: + # Pass new blocks h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) @@ -550,10 +559,10 @@ def _pobtas_streaming( h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -566,6 +575,7 @@ def _pobtas_streaming( compute_B_events[i % 2].record(compute_stream) + # Pass previous B block back d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) B_previous_d[(i - 1) % 2].get( @@ -576,6 +586,7 @@ def _pobtas_streaming( ) d2h_events[i % 2].record(stream=d2h_stream) + # Pass last B block back d2h_stream.wait_event(compute_B_events[0]) B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 0420841c..a532f892 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -204,12 +204,27 @@ def _pobts_streaming( d2h_events = [cp.cuda.Event(), cp.cuda.Event()] if trans == "N": + # ----- Forward 
substitution ----- + + # --- H2D: transfers --- B_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) h2d_events[1].record(stream=h2d_stream) + if n_diag_blocks > 1: + + B_d[1].set( + arr=B[diag_blocksize : (2 * diag_blocksize)], + stream=h2d_stream + ) + L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) + L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + + h2d_events[0].record(stream=h2d_stream) + with compute_stream: + # Solve first B block compute_stream.wait_event(h2d_events[1]) B_previous_d[0] = ( @@ -223,21 +238,13 @@ def _pobts_streaming( compute_B_events[0].record(stream=compute_stream) - if n_diag_blocks > 1: - - B_d[1].set( - arr=B[diag_blocksize : (2 * diag_blocksize)], - stream=h2d_stream - ) - L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) - L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) - - h2d_events[0].record(stream=h2d_stream) + for i in range(1, n_diag_blocks): if i + 1 < n_diag_blocks: + # Pass next blocks h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) From b8871ccb81e340989d2cd4fc54c6b5884f58a364 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 1 May 2025 18:44:35 +0000 Subject: [PATCH 203/518] moved streaming and added documentation --- src/serinv/algs/pobts.py | 41 ++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index a532f892..295be756 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -213,7 +213,6 @@ def _pobts_streaming( h2d_events[1].record(stream=h2d_stream) if n_diag_blocks > 1: - B_d[1].set( arr=B[diag_blocksize : (2 * diag_blocksize)], stream=h2d_stream @@ -237,12 +236,8 @@ def _pobts_streaming( 
compute_B_events[0].record(stream=compute_stream) - - - for i in range(1, n_diag_blocks): - if i + 1 < n_diag_blocks: # Pass next blocks h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) @@ -254,10 +249,10 @@ def _pobts_streaming( h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) - # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -268,6 +263,7 @@ def _pobts_streaming( compute_B_events[i % 2].record(compute_stream) + # Pass previous B block back d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) B_previous_d[(i + 1) % 2].get( @@ -278,18 +274,34 @@ def _pobts_streaming( d2h_events[i % 2].record(stream=d2h_stream) + # Pass last B block back d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2]) B_previous_d[(n_diag_blocks + 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) elif trans == "T" or trans == "C": + # ----- Backward substitution ----- + + # --- H2D: transfers --- B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + if n_diag_blocks > 1: + + B_d[n_diag_blocks % 2].set( + arr=B[-(2 * diag_blocksize) : -diag_blocksize], + stream=h2d_stream + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + with compute_stream: + # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) 
compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) B_previous_d[(n_diag_blocks - 1) % 2] = ( @@ -303,20 +315,12 @@ def _pobts_streaming( compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) - if n_diag_blocks > 1: - - B_d[n_diag_blocks % 2].set( - arr=B[-(2 * diag_blocksize) : -diag_blocksize], - stream=h2d_stream - ) - L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) - L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - - h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + for i in range(n_diag_blocks - 2, -1, -1): if i > 0: + # pass next blocks h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) @@ -326,10 +330,10 @@ def _pobts_streaming( h2d_events[i % 2].record(stream=h2d_stream) with compute_stream: + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -341,13 +345,14 @@ def _pobts_streaming( compute_B_events[i % 2].record(compute_stream) + # Pass previous B block back d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) B_previous_d[(i - 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) d2h_events[i % 2].record(stream=d2h_stream) - + # Pass last B block back d2h_stream.wait_event(compute_B_events[0]) B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) From 69c20abb211c9d1ef9be20d4166fb41b94dc07bd Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 2 May 2025 12:59:42 +0000 Subject: [PATCH 204/518] bigger tests --- tests/tests_algs/regular/conftest.py | 3 +++ 
tests/tests_algs/regular/tests_bt/test_pobts.py | 2 +- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/tests_algs/regular/conftest.py b/tests/tests_algs/regular/conftest.py index 239baed2..db19b997 100644 --- a/tests/tests_algs/regular/conftest.py +++ b/tests/tests_algs/regular/conftest.py @@ -9,6 +9,9 @@ pytest.param(2, id="n_diag_blocks=2"), pytest.param(3, id="n_diag_blocks=3"), pytest.param(4, id="n_diag_blocks=4"), + pytest.param(4, id="n_diag_blocks=125"), + pytest.param(4, id="n_diag_blocks=500"), + pytest.param(4, id="n_diag_blocks=1000"), ] diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index f5c941dc..b8fb5ad3 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -12,7 +12,7 @@ import cupyx as cpx @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000, 4000]) def test_pobts( n_rhs: int, diagonal_blocksize: int, diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index e9ce2384..1c4a6c22 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -13,7 +13,7 @@ @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000, 4000]) def test_pobtas( n_rhs: int, diagonal_blocksize: int, From f2570605711b8f1d343652ffd497c2098a63852b Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 2 May 2025 13:02:18 +0000 Subject: [PATCH 205/518] even bigger tests --- tests/tests_algs/regular/tests_bt/test_pobts.py | 2 +- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py 
b/tests/tests_algs/regular/tests_bt/test_pobts.py index b8fb5ad3..58168ab3 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -12,7 +12,7 @@ import cupyx as cpx @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000, 4000]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000, 4000, 8000, 16000]) def test_pobts( n_rhs: int, diagonal_blocksize: int, diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 1c4a6c22..763a2679 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -13,7 +13,7 @@ @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000, 4000]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000, 4000, 8000, 16000]) def test_pobtas( n_rhs: int, diagonal_blocksize: int, From 0416d8abba80ad7b742e96c7ec0b4fd8b3bdc992 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 2 May 2025 13:04:59 +0000 Subject: [PATCH 206/518] even more bigger tests --- tests/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 3e624933..f8ec3b56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,9 @@ DIAGONAL_BLOCKSIZE = [ pytest.param(2, id="diagonal_blocksize=2"), pytest.param(3, id="diagonal_blocksize=3"), + pytest.param(500, id="diagonal_blocksize=500"), + pytest.param(500, id="diagonal_blocksize=1000"), + pytest.param(500, id="diagonal_blocksize=4000"), ] From b00d95b8e39b752872bf26d9fb4ddd80ca19b29b Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 2 May 2025 13:18:54 +0000 Subject: [PATCH 207/518] changed tests to be smaller --- tests/conftest.py | 2 -- tests/tests_algs/regular/conftest.py | 4 +--- tests/tests_algs/regular/tests_bt/test_pobts.py | 2 +- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) diff 
--git a/tests/conftest.py b/tests/conftest.py index f8ec3b56..ac1a938a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,8 +25,6 @@ pytest.param(2, id="diagonal_blocksize=2"), pytest.param(3, id="diagonal_blocksize=3"), pytest.param(500, id="diagonal_blocksize=500"), - pytest.param(500, id="diagonal_blocksize=1000"), - pytest.param(500, id="diagonal_blocksize=4000"), ] diff --git a/tests/tests_algs/regular/conftest.py b/tests/tests_algs/regular/conftest.py index db19b997..0d22276e 100644 --- a/tests/tests_algs/regular/conftest.py +++ b/tests/tests_algs/regular/conftest.py @@ -9,9 +9,7 @@ pytest.param(2, id="n_diag_blocks=2"), pytest.param(3, id="n_diag_blocks=3"), pytest.param(4, id="n_diag_blocks=4"), - pytest.param(4, id="n_diag_blocks=125"), - pytest.param(4, id="n_diag_blocks=500"), - pytest.param(4, id="n_diag_blocks=1000"), + pytest.param(125, id="n_diag_blocks=125"), ] diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index 58168ab3..0cff67fe 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -12,7 +12,7 @@ import cupyx as cpx @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000, 4000, 8000, 16000]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000]) def test_pobts( n_rhs: int, diagonal_blocksize: int, diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 763a2679..288af233 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -13,7 +13,7 @@ @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000, 4000, 8000, 16000]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000]) def test_pobtas( n_rhs: int, diagonal_blocksize: int, From 6dc83e6d58e3aaae7ba5ad12505b9be1cc2e9492 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 2 May 2025 
13:21:29 +0000 Subject: [PATCH 208/518] smaller tests again --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index ac1a938a..5f6d4827 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ DIAGONAL_BLOCKSIZE = [ pytest.param(2, id="diagonal_blocksize=2"), pytest.param(3, id="diagonal_blocksize=3"), - pytest.param(500, id="diagonal_blocksize=500"), + pytest.param(20, id="diagonal_blocksize=20"), ] From 2645a8fa5484142cde799953d284c6e4b7d5d2d4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 2 May 2025 13:23:25 +0000 Subject: [PATCH 209/518] reset tests --- tests/conftest.py | 1 - tests/tests_algs/regular/conftest.py | 1 - tests/tests_algs/regular/tests_bt/test_pobts.py | 2 +- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 +- 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5f6d4827..3e624933 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,6 @@ DIAGONAL_BLOCKSIZE = [ pytest.param(2, id="diagonal_blocksize=2"), pytest.param(3, id="diagonal_blocksize=3"), - pytest.param(20, id="diagonal_blocksize=20"), ] diff --git a/tests/tests_algs/regular/conftest.py b/tests/tests_algs/regular/conftest.py index 0d22276e..239baed2 100644 --- a/tests/tests_algs/regular/conftest.py +++ b/tests/tests_algs/regular/conftest.py @@ -9,7 +9,6 @@ pytest.param(2, id="n_diag_blocks=2"), pytest.param(3, id="n_diag_blocks=3"), pytest.param(4, id="n_diag_blocks=4"), - pytest.param(125, id="n_diag_blocks=125"), ] diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index 0cff67fe..f5c941dc 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -12,7 +12,7 @@ import cupyx as cpx @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3]) 
def test_pobts( n_rhs: int, diagonal_blocksize: int, diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 288af233..e9ce2384 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -13,7 +13,7 @@ @pytest.mark.mpi_skip() -@pytest.mark.parametrize("n_rhs", [1, 2, 3, 500, 2000]) +@pytest.mark.parametrize("n_rhs", [1, 2, 3]) def test_pobtas( n_rhs: int, diagonal_blocksize: int, From 7329ec3ada7580fb1a599e13a4c7d74c2d1ecfd6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 8 May 2025 13:59:55 +0200 Subject: [PATCH 210/518] add scripts for cscs --- run_streamlined_sequential_pobtax_gpu.sh | 51 +++++ ...d_sequential_pobtax_gpu.sh:Zone.Identifier | 3 + .../streamlined_sequential_pobtax_gpu.py | 195 ++++++++++++++++++ ...d_sequential_pobtax_gpu.py:Zone.Identifier | 3 + 4 files changed, 252 insertions(+) create mode 100644 run_streamlined_sequential_pobtax_gpu.sh create mode 100644 sc25_runs/positive_definite/run_streamlined_sequential_pobtax_gpu.sh:Zone.Identifier create mode 100644 sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py create mode 100644 sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py:Zone.Identifier diff --git a/run_streamlined_sequential_pobtax_gpu.sh b/run_streamlined_sequential_pobtax_gpu.sh new file mode 100644 index 00000000..fe03f6bf --- /dev/null +++ b/run_streamlined_sequential_pobtax_gpu.sh @@ -0,0 +1,51 @@ +#!/bin/bash -l +#SBATCH --job-name="serinv_pobtx_benchmark" +#SBATCH --output=%x.%j.out +#SBATCH --error=%x.%j.err +#SBATCH --account=lp82 +#SBATCH --time=00:10:00 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gpus-per-task=1 +#SBATCH --partition=debug +#SBATCH --constraint=gpu +#SBATCH --hint=nomultithread +#SBATCH --uenv=prgenv-gnu/24.11:v1 +#SBATCH --view=modules + +set -e -u + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +export 
MPICH_GPU_SUPPORT_ENABLED=1 +export OMP_PLACES=cores +export OMP_PROC_BIND=close + +export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID + +source ~/load_modules.sh +conda activate serinv_env + +# Dataset 1: b = 1675, a = 6, n = 128 +# Reference timings (to beat!): +# - pobtaf: 0.38959 +# - pobtas: 0.02415 +# - pobtasi: 0.29593 +# export b=1675 +# export a=6 +# export n=128 + +# Dataset 1: b = 4002, a = 6, n = 250 +# Reference timings (to beat!): +# - pobtaf: 3.2716 (INLA_BTA CUDA code: 2.713) +# - pobtas: 0.15397 +# - pobtasi: 5.15729 +export b=4002 +export a=6 +export n=250 + +# Benchmark the code +srun python ~/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py --b $b --a $a --n $n + +# Profile the code +# srun nsys profile --force-overwrite=true -o profile_serinv_pobtax_b${b}_a${a}_n${n} python ~/repositories/serinv/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py --b $b --a $a --n $n --b $b --a $a --n $n \ No newline at end of file diff --git a/sc25_runs/positive_definite/run_streamlined_sequential_pobtax_gpu.sh:Zone.Identifier b/sc25_runs/positive_definite/run_streamlined_sequential_pobtax_gpu.sh:Zone.Identifier new file mode 100644 index 00000000..33e02d64 --- /dev/null +++ b/sc25_runs/positive_definite/run_streamlined_sequential_pobtax_gpu.sh:Zone.Identifier @@ -0,0 +1,3 @@ +[ZoneTransfer] +ZoneId=3 +HostUrl=https://iis-mattermost.ee.ethz.ch/api/v4/files/waiggpk1miyeb84dcahdh53b1e?download=1 diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py new file mode 100644 index 00000000..88f549cd --- /dev/null +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -0,0 +1,195 @@ +import time + +tic = time.perf_counter() +import argparse + +import numpy as np +import cupy as cp +from cupy.cuda.nvtx import RangePush, RangePop + +from serinv.algs import pobtaf, pobtas, pobtasi + + +def sequential_dataset( + n_blocks: int, + 
diagonal_blocksize: int, + arrowhead_blocksize: int, +): + A_diagonal_blocks = np.random.rand(n_blocks, diagonal_blocksize, diagonal_blocksize) + A_lower_diagonal_blocks = np.random.rand( + n_blocks - 1, diagonal_blocksize, diagonal_blocksize + ) + A_arrow_bottom_blocks = np.random.rand( + n_blocks, arrowhead_blocksize, diagonal_blocksize + ) + A_arrow_tip_block = np.random.rand(arrowhead_blocksize, arrowhead_blocksize) + + # CODE TO MODIFY + arrow_colsum = np.zeros((arrowhead_blocksize), dtype=A_diagonal_blocks.dtype) + for i in range(A_diagonal_blocks.shape[0]): + colsum = np.sum(A_diagonal_blocks[i, :, :], axis=1) - np.diag( + A_diagonal_blocks[i, :, :] + ) + if i > 0: + colsum += np.sum(A_lower_diagonal_blocks[i - 1, :, :], axis=1) + + A_diagonal_blocks[i, :, :] += np.diag(colsum) + + arrow_colsum[:] += np.sum(A_arrow_bottom_blocks[i, :, :], axis=1) + + A_arrow_tip_block[:, :] += np.diag( + arrow_colsum + np.sum(A_arrow_tip_block[:, :], axis=1) + ) + + return ( + A_diagonal_blocks, + A_lower_diagonal_blocks, + A_arrow_bottom_blocks, + A_arrow_tip_block, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process some integers.") + parser.add_argument( + "--b", + type=int, + default=128, + help="an integer for the diagonal block size", + ) + parser.add_argument( + "--a", + type=int, + default=0, + help="an integer for the diagonal block size", + ) + parser.add_argument( + "--n", + type=int, + default=8, + help="an integer for the number of diagonal blocks", + ) + args = parser.parse_args() + toc = time.perf_counter() + print(f"Import and parsing took: {toc - tic:.5f} sec", flush=True) + + diagonal_blocksize = args.b + arrowhead_blocksize = args.a + n_blocks = args.n + n_iterations = 10 + n_warmups = 2 + + tic = time.perf_counter() + ( + A_diagonal_blocks_cpu, + A_lower_diagonal_blocks_cpu, + A_arrow_bottom_blocks_cpu, + A_arrow_tip_block_cpu, + ) = sequential_dataset( + n_blocks, + diagonal_blocksize, + arrowhead_blocksize, + ) + 
B_cpu = np.random.rand(diagonal_blocksize * n_blocks + arrowhead_blocksize, 1) + toc = time.perf_counter() + print(f"Generate dataset took: {toc - tic:.5f} sec", flush=True) + print(f" b = {diagonal_blocksize}", flush=True) + print(f" a = {arrowhead_blocksize}", flush=True) + print(f" n = {n_blocks}", flush=True) + print(f" n_iterations = {n_iterations}", flush=True) + print(f" n_warmups = {n_warmups}", flush=True) + + total_memory = ( + A_diagonal_blocks_cpu.nbytes + + A_lower_diagonal_blocks_cpu.nbytes + + A_arrow_bottom_blocks_cpu.nbytes + + A_arrow_tip_block_cpu.nbytes + + B_cpu.nbytes + ) + print(f" Total memory: {total_memory / 1e9:.5f} GB", flush=True) + + tic = time.perf_counter() + # Init device arrays + A_diagonal_blocks_gpu = cp.empty_like(A_diagonal_blocks_cpu) + A_lower_diagonal_blocks_gpu = cp.empty_like(A_lower_diagonal_blocks_cpu) + A_arrow_bottom_blocks_gpu = cp.empty_like(A_arrow_bottom_blocks_cpu) + A_arrow_tip_block_gpu = cp.empty_like(A_arrow_tip_block_cpu) + B_gpu = cp.empty_like(B_cpu) + toc = time.perf_counter() + print(f"Init device arrays took: {toc - tic:.5f} sec", flush=True) + + t_pobtaf = [] + t_pobtas = [] + t_pobtasi = [] + + for i in range(n_warmups + n_iterations): + print(f"Iteration: {i+1}/{n_warmups+n_iterations}", flush=True) + + tic = time.perf_counter() + A_diagonal_blocks_gpu.set(arr=A_diagonal_blocks_cpu) + A_lower_diagonal_blocks_gpu.set(arr=A_lower_diagonal_blocks_cpu) + A_arrow_bottom_blocks_gpu.set(arr=A_arrow_bottom_blocks_cpu) + A_arrow_tip_block_gpu.set(arr=A_arrow_tip_block_cpu) + B_gpu.set(arr=B_cpu) + toc = time.perf_counter() + print(f"Copying data to GPU took: {toc - tic:.5f} sec", flush=True) + + cp.cuda.runtime.deviceSynchronize() + RangePush(f"pobtaf: i:{i}") + tic = time.perf_counter() + pobtaf( + A_diagonal_blocks_gpu, + A_lower_diagonal_blocks_gpu, + A_arrow_bottom_blocks_gpu, + A_arrow_tip_block_gpu, + ) + cp.cuda.runtime.deviceSynchronize() + toc = time.perf_counter() + RangePop() + elapsed = toc - tic + 
print(f"pobtaf took: {elapsed:.5f} sec", flush=True) + if i >= n_warmups: + t_pobtaf.append(elapsed) + + cp.cuda.runtime.deviceSynchronize() + RangePush(f"pobtas: i:{i}") + tic = time.perf_counter() + pobtas( + A_diagonal_blocks_gpu, + A_lower_diagonal_blocks_gpu, + A_arrow_bottom_blocks_gpu, + A_arrow_tip_block_gpu, + B_gpu, + ) + cp.cuda.runtime.deviceSynchronize() + toc = time.perf_counter() + RangePop() + elapsed = toc - tic + print(f"pobtas took: {elapsed:.5f} sec", flush=True) + if i >= n_warmups: + t_pobtas.append(elapsed) + + cp.cuda.runtime.deviceSynchronize() + RangePush(f"pobtasi: i:{i}") + tic = time.perf_counter() + pobtasi( + A_diagonal_blocks_gpu, + A_lower_diagonal_blocks_gpu, + A_arrow_bottom_blocks_gpu, + A_arrow_tip_block_gpu, + ) + cp.cuda.runtime.deviceSynchronize() + toc = time.perf_counter() + RangePop() + elapsed = toc - tic + print(f"pobtasi took: {elapsed:.5f} sec", flush=True) + if i >= n_warmups: + t_pobtasi.append(elapsed) + + print(f"t_pobtaf: {t_pobtaf}", flush=True) + print(f"t_pobtas: {t_pobtas}", flush=True) + print(f"t_pobtasi: {t_pobtasi}", flush=True) + + print(f"avg t_pobtaf: {np.mean(np.array(t_pobtaf)):.5f} sec", flush=True) + print(f"avg t_pobtas: {np.mean(np.array(t_pobtas)):.5f} sec", flush=True) + print(f"avg t_pobtasi: {np.mean(np.array(t_pobtasi)):.5f} sec", flush=True) \ No newline at end of file diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py:Zone.Identifier b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py:Zone.Identifier new file mode 100644 index 00000000..ce8dec59 --- /dev/null +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py:Zone.Identifier @@ -0,0 +1,3 @@ +[ZoneTransfer] +ZoneId=3 +HostUrl=https://iis-mattermost.ee.ethz.ch/api/v4/files/fw5m5tapefbi8deseto5qqro9w?download=1 From 04cbd7779b160e6031ee0613df0c8812fbce2e34 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 8 May 2025 14:34:20 +0000 Subject: [PATCH 211/518] updarte bash script --- 
run_streamlined_sequential_pobtax_gpu.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run_streamlined_sequential_pobtax_gpu.sh b/run_streamlined_sequential_pobtax_gpu.sh index fe03f6bf..5d13342c 100644 --- a/run_streamlined_sequential_pobtax_gpu.sh +++ b/run_streamlined_sequential_pobtax_gpu.sh @@ -2,7 +2,7 @@ #SBATCH --job-name="serinv_pobtx_benchmark" #SBATCH --output=%x.%j.out #SBATCH --error=%x.%j.err -#SBATCH --account=lp82 +#SBATCH --account=lp16 #SBATCH --time=00:10:00 #SBATCH --nodes=1 #SBATCH --ntasks-per-node=1 @@ -35,7 +35,7 @@ conda activate serinv_env # export a=6 # export n=128 -# Dataset 1: b = 4002, a = 6, n = 250 +# Dataset 2: b = 4002, a = 6, n = 250 # Reference timings (to beat!): # - pobtaf: 3.2716 (INLA_BTA CUDA code: 2.713) # - pobtas: 0.15397 From 6a85bc5987127889e341a69bd7488382b365c002 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 9 May 2025 07:36:33 +0000 Subject: [PATCH 212/518] removed load_modules --- run_streamlined_sequential_pobtax_gpu.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_streamlined_sequential_pobtax_gpu.sh b/run_streamlined_sequential_pobtax_gpu.sh index 5d13342c..33d81986 100644 --- a/run_streamlined_sequential_pobtax_gpu.sh +++ b/run_streamlined_sequential_pobtax_gpu.sh @@ -23,7 +23,7 @@ export OMP_PROC_BIND=close export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID -source ~/load_modules.sh +# source ~/load_modules.sh conda activate serinv_env # Dataset 1: b = 1675, a = 6, n = 128 From d10b43692633929fccd304b65d17e643413d50e4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 9 May 2025 08:19:11 +0000 Subject: [PATCH 213/518] changed file path --- run_streamlined_sequential_pobtax_gpu.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_streamlined_sequential_pobtax_gpu.sh b/run_streamlined_sequential_pobtax_gpu.sh index 33d81986..74b5d9db 100644 --- a/run_streamlined_sequential_pobtax_gpu.sh +++ b/run_streamlined_sequential_pobtax_gpu.sh @@ -45,7 +45,7 @@ 
export a=6 export n=250 # Benchmark the code -srun python ~/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py --b $b --a $a --n $n +srun python ~/serinv/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py --b $b --a $a --n $n # Profile the code # srun nsys profile --force-overwrite=true -o profile_serinv_pobtax_b${b}_a${a}_n${n} python ~/repositories/serinv/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py --b $b --a $a --n $n --b $b --a $a --n $n \ No newline at end of file From 1a64890a42104fdaba7abff1721a75b11ee54f3b Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 9 May 2025 08:42:06 +0000 Subject: [PATCH 214/518] change to enable streaming on daint --- sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 88f549cd..a6f6be10 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -160,6 +160,7 @@ def sequential_dataset( A_arrow_bottom_blocks_gpu, A_arrow_tip_block_gpu, B_gpu, + device_streaming=True ) cp.cuda.runtime.deviceSynchronize() toc = time.perf_counter() From d21064f4b58258aba42bc8f4e7ddbaa9c83f0f6b Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 07:44:49 +0000 Subject: [PATCH 215/518] added check message --- src/serinv/algs/pobtas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f5f52976..a6a528f6 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -238,6 +238,8 @@ def _pobtas_streaming( raise NotImplementedError( "Host<->Device streaming only works when host-arrays are given." 
) + + print("streaming") cp, cu_la = _get_module_from_str(module_str="cupy") From 86ce7c185ac035ef3f3ffd9e6e9a067788d8a57b Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 08:38:42 +0000 Subject: [PATCH 216/518] changed given arrays --- .../streamlined_sequential_pobtax_gpu.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index a6f6be10..9002afaf 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -138,10 +138,10 @@ def sequential_dataset( RangePush(f"pobtaf: i:{i}") tic = time.perf_counter() pobtaf( - A_diagonal_blocks_gpu, - A_lower_diagonal_blocks_gpu, - A_arrow_bottom_blocks_gpu, - A_arrow_tip_block_gpu, + A_diagonal_blocks_cpu, + A_lower_diagonal_blocks_cpu, + A_arrow_bottom_blocks_cpu, + A_arrow_tip_block_cpu, ) cp.cuda.runtime.deviceSynchronize() toc = time.perf_counter() @@ -155,11 +155,11 @@ def sequential_dataset( RangePush(f"pobtas: i:{i}") tic = time.perf_counter() pobtas( - A_diagonal_blocks_gpu, - A_lower_diagonal_blocks_gpu, - A_arrow_bottom_blocks_gpu, - A_arrow_tip_block_gpu, - B_gpu, + A_diagonal_blocks_cpu, + A_lower_diagonal_blocks_cpu, + A_arrow_bottom_blocks_cpu, + A_arrow_tip_block_cpu, + B_cpu, device_streaming=True ) cp.cuda.runtime.deviceSynchronize() @@ -174,10 +174,10 @@ def sequential_dataset( RangePush(f"pobtasi: i:{i}") tic = time.perf_counter() pobtasi( - A_diagonal_blocks_gpu, - A_lower_diagonal_blocks_gpu, - A_arrow_bottom_blocks_gpu, - A_arrow_tip_block_gpu, + A_diagonal_blocks_cpu, + A_lower_diagonal_blocks_cpu, + A_arrow_bottom_blocks_cpu, + A_arrow_tip_block_cpu, ) cp.cuda.runtime.deviceSynchronize() toc = time.perf_counter() From 942a1465c33d9470f4a4a4539583f89f3f73ea18 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 09:00:23 +0000 
Subject: [PATCH 217/518] rolled back block choice for further testing --- .../streamlined_sequential_pobtax_gpu.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 9002afaf..8d575c17 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -138,10 +138,11 @@ def sequential_dataset( RangePush(f"pobtaf: i:{i}") tic = time.perf_counter() pobtaf( - A_diagonal_blocks_cpu, - A_lower_diagonal_blocks_cpu, - A_arrow_bottom_blocks_cpu, - A_arrow_tip_block_cpu, + A_diagonal_blocks_gpu, + A_lower_diagonal_blocks_gpu, + A_arrow_bottom_blocks_gpu, + A_arrow_tip_block_gpu, + device_streaming=True ) cp.cuda.runtime.deviceSynchronize() toc = time.perf_counter() @@ -155,11 +156,11 @@ def sequential_dataset( RangePush(f"pobtas: i:{i}") tic = time.perf_counter() pobtas( - A_diagonal_blocks_cpu, - A_lower_diagonal_blocks_cpu, - A_arrow_bottom_blocks_cpu, - A_arrow_tip_block_cpu, - B_cpu, + A_diagonal_blocks_gpu, + A_lower_diagonal_blocks_gpu, + A_arrow_bottom_blocks_gpu, + A_arrow_tip_block_gpu, + B_gpu, device_streaming=True ) cp.cuda.runtime.deviceSynchronize() @@ -174,10 +175,10 @@ def sequential_dataset( RangePush(f"pobtasi: i:{i}") tic = time.perf_counter() pobtasi( - A_diagonal_blocks_cpu, - A_lower_diagonal_blocks_cpu, - A_arrow_bottom_blocks_cpu, - A_arrow_tip_block_cpu, + A_diagonal_blocks_gpu, + A_lower_diagonal_blocks_gpu, + A_arrow_bottom_blocks_gpu, + A_arrow_tip_block_gpu, ) cp.cuda.runtime.deviceSynchronize() toc = time.perf_counter() From dcc85fb0a85111472bdbd08aa0d36944fa8792b5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 09:06:46 +0000 Subject: [PATCH 218/518] attempt to activate streaming --- .../streamlined_sequential_pobtax_gpu.py | 30 +++++++++++++++---- 1 file changed, 24 
insertions(+), 6 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 8d575c17..98206bd8 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -142,7 +142,7 @@ def sequential_dataset( A_lower_diagonal_blocks_gpu, A_arrow_bottom_blocks_gpu, A_arrow_tip_block_gpu, - device_streaming=True + # device_streaming=True ) cp.cuda.runtime.deviceSynchronize() toc = time.perf_counter() @@ -152,15 +152,24 @@ def sequential_dataset( if i >= n_warmups: t_pobtaf.append(elapsed) + tic = time.perf_counter() + A_diagonal_blocks_gpu.get(arr=A_diagonal_blocks_cpu) + A_lower_diagonal_blocks_gpu.get(arr=A_lower_diagonal_blocks_cpu) + A_arrow_bottom_blocks_gpu.get(arr=A_arrow_bottom_blocks_cpu) + A_arrow_tip_block_gpu.get(arr=A_arrow_tip_block_cpu) + B_gpu.get(arr=B_cpu) + toc = time.perf_counter() + print(f"Copying data from GPU took: {toc - tic:.5f} sec", flush=True) + cp.cuda.runtime.deviceSynchronize() RangePush(f"pobtas: i:{i}") tic = time.perf_counter() pobtas( - A_diagonal_blocks_gpu, - A_lower_diagonal_blocks_gpu, - A_arrow_bottom_blocks_gpu, - A_arrow_tip_block_gpu, - B_gpu, + A_diagonal_blocks_cpu, + A_lower_diagonal_blocks_cpu, + A_arrow_bottom_blocks_cpu, + A_arrow_tip_block_cpu, + B_cpu, device_streaming=True ) cp.cuda.runtime.deviceSynchronize() @@ -171,6 +180,15 @@ def sequential_dataset( if i >= n_warmups: t_pobtas.append(elapsed) + tic = time.perf_counter() + A_diagonal_blocks_gpu.set(arr=A_diagonal_blocks_cpu) + A_lower_diagonal_blocks_gpu.set(arr=A_lower_diagonal_blocks_cpu) + A_arrow_bottom_blocks_gpu.set(arr=A_arrow_bottom_blocks_cpu) + A_arrow_tip_block_gpu.set(arr=A_arrow_tip_block_cpu) + B_gpu.set(arr=B_cpu) + toc = time.perf_counter() + print(f"Copying data to GPU took: {toc - tic:.5f} sec", flush=True) + cp.cuda.runtime.deviceSynchronize() 
RangePush(f"pobtasi: i:{i}") tic = time.perf_counter() From a73b7b900dfeaad694041a7a62be745366bebeb3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 09:09:32 +0000 Subject: [PATCH 219/518] typo --- .../streamlined_sequential_pobtax_gpu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 98206bd8..c758bf0f 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -153,10 +153,10 @@ def sequential_dataset( t_pobtaf.append(elapsed) tic = time.perf_counter() - A_diagonal_blocks_gpu.get(arr=A_diagonal_blocks_cpu) - A_lower_diagonal_blocks_gpu.get(arr=A_lower_diagonal_blocks_cpu) - A_arrow_bottom_blocks_gpu.get(arr=A_arrow_bottom_blocks_cpu) - A_arrow_tip_block_gpu.get(arr=A_arrow_tip_block_cpu) + A_diagonal_blocks_gpu.get(out=A_diagonal_blocks_cpu) + A_lower_diagonal_blocks_gpu.get(out=A_lower_diagonal_blocks_cpu) + A_arrow_bottom_blocks_gpu.get(out=A_arrow_bottom_blocks_cpu) + A_arrow_tip_block_gpu.get(out=A_arrow_tip_block_cpu) B_gpu.get(arr=B_cpu) toc = time.perf_counter() print(f"Copying data from GPU took: {toc - tic:.5f} sec", flush=True) From 408628e717420f35090199cf45cbe5cf05bc8445 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 09:32:12 +0000 Subject: [PATCH 220/518] another typo --- .../positive_definite/streamlined_sequential_pobtax_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index c758bf0f..7e65e662 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -157,7 +157,7 @@ def sequential_dataset( 
A_lower_diagonal_blocks_gpu.get(out=A_lower_diagonal_blocks_cpu) A_arrow_bottom_blocks_gpu.get(out=A_arrow_bottom_blocks_cpu) A_arrow_tip_block_gpu.get(out=A_arrow_tip_block_cpu) - B_gpu.get(arr=B_cpu) + B_gpu.get(out=B_cpu) toc = time.perf_counter() print(f"Copying data from GPU took: {toc - tic:.5f} sec", flush=True) From 060fd0b1c6aa7c06e28b94e7ae8504f48207dec7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 09:53:12 +0000 Subject: [PATCH 221/518] enable streaming for pobtaf --- .../streamlined_sequential_pobtax_gpu.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 7e65e662..a35d6a2c 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -138,11 +138,11 @@ def sequential_dataset( RangePush(f"pobtaf: i:{i}") tic = time.perf_counter() pobtaf( - A_diagonal_blocks_gpu, - A_lower_diagonal_blocks_gpu, - A_arrow_bottom_blocks_gpu, - A_arrow_tip_block_gpu, - # device_streaming=True + A_diagonal_blocks_cpu, + A_lower_diagonal_blocks_cpu, + A_arrow_bottom_blocks_cpu, + A_arrow_tip_block_cpu, + device_streaming=True ) cp.cuda.runtime.deviceSynchronize() toc = time.perf_counter() From d08b7b3924811db11e432cc7209499efd4c740fc Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 09:54:05 +0000 Subject: [PATCH 222/518] removing copy --- .../streamlined_sequential_pobtax_gpu.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index a35d6a2c..1b879ed3 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -152,14 +152,14 @@ def 
sequential_dataset( if i >= n_warmups: t_pobtaf.append(elapsed) - tic = time.perf_counter() - A_diagonal_blocks_gpu.get(out=A_diagonal_blocks_cpu) - A_lower_diagonal_blocks_gpu.get(out=A_lower_diagonal_blocks_cpu) - A_arrow_bottom_blocks_gpu.get(out=A_arrow_bottom_blocks_cpu) - A_arrow_tip_block_gpu.get(out=A_arrow_tip_block_cpu) - B_gpu.get(out=B_cpu) - toc = time.perf_counter() - print(f"Copying data from GPU took: {toc - tic:.5f} sec", flush=True) + #tic = time.perf_counter() + #A_diagonal_blocks_gpu.get(out=A_diagonal_blocks_cpu) + #A_lower_diagonal_blocks_gpu.get(out=A_lower_diagonal_blocks_cpu) + #A_arrow_bottom_blocks_gpu.get(out=A_arrow_bottom_blocks_cpu) + #A_arrow_tip_block_gpu.get(out=A_arrow_tip_block_cpu) + #B_gpu.get(out=B_cpu) + #toc = time.perf_counter() + #print(f"Copying data from GPU took: {toc - tic:.5f} sec", flush=True) cp.cuda.runtime.deviceSynchronize() RangePush(f"pobtas: i:{i}") From ac4779950ad5886bce58f77174d252800994b1e6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 11:33:46 +0000 Subject: [PATCH 223/518] pinned memory --- .../streamlined_sequential_pobtax_gpu.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 1b879ed3..35d2ce5a 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -118,6 +118,23 @@ def sequential_dataset( toc = time.perf_counter() print(f"Init device arrays took: {toc - tic:.5f} sec", flush=True) + A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks_cpu) + A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks_cpu[:, :, :] + A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks_cpu) + A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks_cpu[:, :, :] + A_lower_arrow_blocks_pinned = 
cpx.zeros_like_pinned(A_lower_arrow_blocks_cpu) + A_lower_arrow_blocks_pinned[:, :, :] = A_lower_arrow_blocks_cpu[:, :, :] + A_arrow_tip_block_pinned = cpx.zeros_like_pinned(A_arrow_tip_block_cpu) + A_arrow_tip_block_pinned[:, :] = A_arrow_tip_block_cpu[:, :] + B_pinned = cpx.zeros_like_pinned(B_cpu) + B_pinned[:, :] = B_cpu[:, :] + + A_diagonal_blocks_cpu = A_diagonal_blocks_pinned + A_lower_diagonal_blocks_cpu = A_lower_diagonal_blocks_pinned + A_lower_arrow_blocks_cpu = A_lower_arrow_blocks_pinned + A_arrow_tip_block_cpu = A_arrow_tip_block_pinned + B = B_pinned + t_pobtaf = [] t_pobtas = [] t_pobtasi = [] From 5faecf625929a7f0ed6071427c6e9b8de650fcf3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 11:43:12 +0000 Subject: [PATCH 224/518] typo --- .../streamlined_sequential_pobtax_gpu.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 35d2ce5a..49537dfa 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -118,15 +118,15 @@ def sequential_dataset( toc = time.perf_counter() print(f"Init device arrays took: {toc - tic:.5f} sec", flush=True) - A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks_cpu) + A_diagonal_blocks_pinned = cp.zeros_like_pinned(A_diagonal_blocks_cpu) A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks_cpu[:, :, :] - A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks_cpu) + A_lower_diagonal_blocks_pinned = cp.zeros_like_pinned(A_lower_diagonal_blocks_cpu) A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks_cpu[:, :, :] - A_lower_arrow_blocks_pinned = cpx.zeros_like_pinned(A_lower_arrow_blocks_cpu) + A_lower_arrow_blocks_pinned = cp.zeros_like_pinned(A_lower_arrow_blocks_cpu) A_lower_arrow_blocks_pinned[:, :, :] = 
A_lower_arrow_blocks_cpu[:, :, :] - A_arrow_tip_block_pinned = cpx.zeros_like_pinned(A_arrow_tip_block_cpu) + A_arrow_tip_block_pinned = cp.zeros_like_pinned(A_arrow_tip_block_cpu) A_arrow_tip_block_pinned[:, :] = A_arrow_tip_block_cpu[:, :] - B_pinned = cpx.zeros_like_pinned(B_cpu) + B_pinned = cp.zeros_like_pinned(B_cpu) B_pinned[:, :] = B_cpu[:, :] A_diagonal_blocks_cpu = A_diagonal_blocks_pinned From e668184223285e83aee1d3d293ad8fc7dc25948c Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 11:44:16 +0000 Subject: [PATCH 225/518] changed block name --- .../positive_definite/streamlined_sequential_pobtax_gpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 49537dfa..5aef4518 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -122,8 +122,8 @@ def sequential_dataset( A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks_cpu[:, :, :] A_lower_diagonal_blocks_pinned = cp.zeros_like_pinned(A_lower_diagonal_blocks_cpu) A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks_cpu[:, :, :] - A_lower_arrow_blocks_pinned = cp.zeros_like_pinned(A_lower_arrow_blocks_cpu) - A_lower_arrow_blocks_pinned[:, :, :] = A_lower_arrow_blocks_cpu[:, :, :] + A_lower_arrow_blocks_pinned = cp.zeros_like_pinned(A_arrow_bottom_blocks_cpu) + A_lower_arrow_blocks_pinned[:, :, :] = A_arrow_bottom_blocks_cpu[:, :, :] A_arrow_tip_block_pinned = cp.zeros_like_pinned(A_arrow_tip_block_cpu) A_arrow_tip_block_pinned[:, :] = A_arrow_tip_block_cpu[:, :] B_pinned = cp.zeros_like_pinned(B_cpu) @@ -131,7 +131,7 @@ def sequential_dataset( A_diagonal_blocks_cpu = A_diagonal_blocks_pinned A_lower_diagonal_blocks_cpu = A_lower_diagonal_blocks_pinned - A_lower_arrow_blocks_cpu = A_lower_arrow_blocks_pinned + 
A_arrow_bottom_blocks_cpu = A_lower_arrow_blocks_pinned A_arrow_tip_block_cpu = A_arrow_tip_block_pinned B = B_pinned From 2a048b04f20faf3c25429f2073c7aecdfc174953 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 11:47:35 +0000 Subject: [PATCH 226/518] import cupyx --- .../streamlined_sequential_pobtax_gpu.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 5aef4518..b5b1833c 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -6,6 +6,7 @@ import numpy as np import cupy as cp from cupy.cuda.nvtx import RangePush, RangePop +import cupyx as cpx from serinv.algs import pobtaf, pobtas, pobtasi @@ -118,15 +119,15 @@ def sequential_dataset( toc = time.perf_counter() print(f"Init device arrays took: {toc - tic:.5f} sec", flush=True) - A_diagonal_blocks_pinned = cp.zeros_like_pinned(A_diagonal_blocks_cpu) + A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks_cpu) A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks_cpu[:, :, :] - A_lower_diagonal_blocks_pinned = cp.zeros_like_pinned(A_lower_diagonal_blocks_cpu) + A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks_cpu) A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks_cpu[:, :, :] - A_lower_arrow_blocks_pinned = cp.zeros_like_pinned(A_arrow_bottom_blocks_cpu) + A_lower_arrow_blocks_pinned = cpx.zeros_like_pinned(A_arrow_bottom_blocks_cpu) A_lower_arrow_blocks_pinned[:, :, :] = A_arrow_bottom_blocks_cpu[:, :, :] - A_arrow_tip_block_pinned = cp.zeros_like_pinned(A_arrow_tip_block_cpu) + A_arrow_tip_block_pinned = cpx.zeros_like_pinned(A_arrow_tip_block_cpu) A_arrow_tip_block_pinned[:, :] = A_arrow_tip_block_cpu[:, :] - B_pinned = cp.zeros_like_pinned(B_cpu) + B_pinned = 
cpx.zeros_like_pinned(B_cpu) B_pinned[:, :] = B_cpu[:, :] A_diagonal_blocks_cpu = A_diagonal_blocks_pinned From 5bc3bcbe62e3b08c6aec6163fa858302fba0efb2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 11:59:46 +0000 Subject: [PATCH 227/518] missing B_cpu --- .../positive_definite/streamlined_sequential_pobtax_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index b5b1833c..6eefb168 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -134,7 +134,7 @@ def sequential_dataset( A_lower_diagonal_blocks_cpu = A_lower_diagonal_blocks_pinned A_arrow_bottom_blocks_cpu = A_lower_arrow_blocks_pinned A_arrow_tip_block_cpu = A_arrow_tip_block_pinned - B = B_pinned + B_cpu = B_pinned t_pobtaf = [] t_pobtas = [] From 1cb143ccd529a6359d2287c08f111e52f075caca Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 12:33:38 +0000 Subject: [PATCH 228/518] changed nvtx --- .../streamlined_sequential_pobtax_gpu.py | 4 ++-- src/serinv/algs/pobtas.py | 22 ++++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 6eefb168..99e325f9 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -180,7 +180,7 @@ def sequential_dataset( #print(f"Copying data from GPU took: {toc - tic:.5f} sec", flush=True) cp.cuda.runtime.deviceSynchronize() - RangePush(f"pobtas: i:{i}") + # RangePush(f"pobtas: i:{i}") tic = time.perf_counter() pobtas( A_diagonal_blocks_cpu, @@ -192,7 +192,7 @@ def sequential_dataset( ) cp.cuda.runtime.deviceSynchronize() toc = time.perf_counter() - 
RangePop() + # RangePop() elapsed = toc - tic print(f"pobtas took: {elapsed:.5f} sec", flush=True) if i >= n_warmups: diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index a6a528f6..07cb064b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -7,6 +7,9 @@ _get_module_from_str, ) +from cupy.cuda.nvtx import RangePush, RangePop + + def pobtas( L_diagonal_blocks: ArrayLike, @@ -239,7 +242,7 @@ def _pobtas_streaming( "Host<->Device streaming only works when host-arrays are given." ) - print("streaming") + cp, cu_la = _get_module_from_str(module_str="cupy") @@ -280,7 +283,7 @@ def _pobtas_streaming( if trans == "N": # ----- Forward substitution ----- - + RangePush(f"pobtas: startup") # Delete helper variable del B_shape @@ -327,27 +330,30 @@ def _pobtas_streaming( L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_lower_diagonal_events[0].record(stream=h2d_stream) + RangePop() # --- Computations --- for i in range(0, n_diag_blocks - 1): # pass next B block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + RangePush(f"pobtas: streaming B {i+1}") B_d[(i + 1) % 2].set( arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream = h2d_stream ) + RangePop() h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) if i + 1 < n_diag_blocks - 1: # pass next diagonal block h2d_stream.wait_event(compute_current_B_events[(i + 1) % 2]) - + RangePush(f"pobtas: streaming diag blocks {i+1}") L_diagonal_blocks_d[(i + 1) % 2].set( arr=L_diagonal_blocks[i + 1], stream=h2d_stream ) - + RangePop() h2d_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) @@ -377,12 +383,12 @@ def _pobtas_streaming( if i + 1 < n_diag_blocks - 1: # Pass next lower diagonal block h2d_stream.wait_event(compute_next_B_events[(i + 1) % 2]) - + RangePush(f"pobtas: streaming lower diag blocks {i+1}") L_lower_diagonal_blocks_d[(i + 1) % 2].set( arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream ) - + RangePop() 
h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: @@ -399,12 +405,12 @@ def _pobtas_streaming( if i + 1 < n_diag_blocks - 1: # Pass next lower arrow block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) - + RangePush(f"pobtas: streaming lower arrow blocks{i}") L_lower_arrow_blocks_d[(i + 1) % 2].set( arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream ) - + RangePop() h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: From 35e5bf928242c415a35013b6b0f0c098a2feda3a Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 12:40:55 +0000 Subject: [PATCH 229/518] moved pop --- src/serinv/algs/pobtas.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 07cb064b..1446755d 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -342,8 +342,8 @@ def _pobtas_streaming( stream = h2d_stream ) - RangePop() h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + RangePop() if i + 1 < n_diag_blocks - 1: # pass next diagonal block @@ -353,8 +353,9 @@ def _pobtas_streaming( arr=L_diagonal_blocks[i + 1], stream=h2d_stream ) - RangePop() + h2d_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) + RangePop() with compute_stream: @@ -388,8 +389,9 @@ def _pobtas_streaming( arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream ) - RangePop() + h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) + RangePop() with compute_stream: # Update next B block @@ -410,8 +412,9 @@ def _pobtas_streaming( arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream ) - RangePop() + h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) + RangePop() with compute_stream: # Update arrow tip From 01f4b24ff6bb354950cded62efb7e7ccd9285b0f Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 15 May 2025 14:57:44 +0000 Subject: [PATCH 230/518] untangled streaming --- .../streamlined_sequential_pobtax_gpu.py | 17 ++++++++------- 
src/serinv/algs/pobtas.py | 21 +++++++++++-------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py index 99e325f9..ddb479a2 100644 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py @@ -136,6 +136,7 @@ def sequential_dataset( A_arrow_tip_block_cpu = A_arrow_tip_block_pinned B_cpu = B_pinned + t_pobtaf = [] t_pobtas = [] t_pobtasi = [] @@ -143,14 +144,14 @@ def sequential_dataset( for i in range(n_warmups + n_iterations): print(f"Iteration: {i+1}/{n_warmups+n_iterations}", flush=True) - tic = time.perf_counter() - A_diagonal_blocks_gpu.set(arr=A_diagonal_blocks_cpu) - A_lower_diagonal_blocks_gpu.set(arr=A_lower_diagonal_blocks_cpu) - A_arrow_bottom_blocks_gpu.set(arr=A_arrow_bottom_blocks_cpu) - A_arrow_tip_block_gpu.set(arr=A_arrow_tip_block_cpu) - B_gpu.set(arr=B_cpu) - toc = time.perf_counter() - print(f"Copying data to GPU took: {toc - tic:.5f} sec", flush=True) + #tic = time.perf_counter() + #A_diagonal_blocks_gpu.set(arr=A_diagonal_blocks_cpu) + #A_lower_diagonal_blocks_gpu.set(arr=A_lower_diagonal_blocks_cpu) + #A_arrow_bottom_blocks_gpu.set(arr=A_arrow_bottom_blocks_cpu) + #A_arrow_tip_block_gpu.set(arr=A_arrow_tip_block_cpu) + #B_gpu.set(arr=B_cpu) + #toc = time.perf_counter() + #print(f"Copying data to GPU took: {toc - tic:.5f} sec", flush=True) cp.cuda.runtime.deviceSynchronize() RangePush(f"pobtaf: i:{i}") diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 1446755d..d8f20d1b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -371,15 +371,7 @@ def _pobtas_streaming( compute_current_B_events[i % 2].record(stream=compute_stream) # Pass current B block back - d2h_stream.wait_event(compute_current_B_events[i % 2]) - - B_d[i % 2].get( - out=B[i * diag_blocksize : (i + 1) * 
diag_blocksize], - stream=d2h_stream, - blocking=False, - ) - - d2h_B_events[i % 2].record(stream=d2h_stream) + if i + 1 < n_diag_blocks - 1: # Pass next lower diagonal block @@ -392,6 +384,17 @@ def _pobtas_streaming( h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) RangePop() + + d2h_stream.wait_event(compute_current_B_events[i % 2]) + d2h_stream.wait_event(h2d_lower_diagonal_events[(i+1) % 2]) + + B_d[i % 2].get( + out=B[i * diag_blocksize : (i + 1) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + + d2h_B_events[i % 2].record(stream=d2h_stream) with compute_stream: # Update next B block From c63cd2c0da0e075203c72e2930d011c6e674e8dd Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:32:40 +0000 Subject: [PATCH 231/518] modified tests --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index e9ce2384..54a0bde5 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -3,15 +3,29 @@ import numpy as np import pytest +from conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize, rhs from serinv.algs import pobtaf, pobtas +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + + pytest.param("streaming", id="streaming"), + ] + ) + if backend_flags["cupy_avail"]: import cupyx as cpx +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) def test_pobtas( @@ -22,7 +36,6 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): - array_type = "streaming" A = dd_bta( diagonal_blocksize, From 
a3905e277b4ee8dc0a12e457116d0bff96f6eb79 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:34:34 +0000 Subject: [PATCH 232/518] pytest array_type override --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 54a0bde5..76e50089 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -3,17 +3,19 @@ import numpy as np import pytest -from conftest import ARRAY_TYPE as ARRAY_TYPE - from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize, rhs from serinv.algs import pobtaf, pobtas + +ARRAY_TYPE = [ + pytest.param("host", id="host"), +] if backend_flags["cupy_avail"]: ARRAY_TYPE.extend( [ - + pytest.param("device", id="device"), pytest.param("streaming", id="streaming"), ] ) From 6913677f503f858247e40f1f90bade9487feb277 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:37:36 +0000 Subject: [PATCH 233/518] changed tests a bit to not override --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 76e50089..3685c174 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -3,19 +3,16 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize, rhs from serinv.algs import pobtaf, pobtas - -ARRAY_TYPE = [ - pytest.param("host", id="host"), -] if backend_flags["cupy_avail"]: ARRAY_TYPE.extend( [ - pytest.param("device", id="device"), 
pytest.param("streaming", id="streaming"), ] ) From f734dc6df6ba619977e44e3b34ac325985faeec8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:40:08 +0000 Subject: [PATCH 234/518] activate pobtaf streaming in tests --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 3685c174..d58e3a19 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -91,6 +91,7 @@ def test_pobtas( A_lower_diagonal_blocks, A_lower_arrow_blocks, A_arrow_tip_block, + device_streaming=True if array_type == "streaming" else False, ) # Forward solve: Y=L^{-1}B From 5924e1b2e8269526ec1f1b24740353e7fdf8b9f6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:43:51 +0000 Subject: [PATCH 235/518] removed nvtx and tests the tests --- src/serinv/algs/pobtas.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index d8f20d1b..9225bd7b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -7,9 +7,6 @@ _get_module_from_str, ) -from cupy.cuda.nvtx import RangePush, RangePop - - def pobtas( L_diagonal_blocks: ArrayLike, @@ -51,6 +48,9 @@ def pobtas( else: # Natural arrowhead if device_streaming: + raise NotImplementedError( + "Test testing." 
+ ) _pobtas_streaming( L_diagonal_blocks, L_lower_diagonal_blocks, @@ -283,7 +283,6 @@ def _pobtas_streaming( if trans == "N": # ----- Forward substitution ----- - RangePush(f"pobtas: startup") # Delete helper variable del B_shape @@ -330,12 +329,12 @@ def _pobtas_streaming( L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) h2d_lower_diagonal_events[0].record(stream=h2d_stream) - RangePop() + # --- Computations --- for i in range(0, n_diag_blocks - 1): # pass next B block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) - RangePush(f"pobtas: streaming B {i+1}") + B_d[(i + 1) % 2].set( arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], @@ -343,19 +342,17 @@ def _pobtas_streaming( ) h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) - RangePop() + if i + 1 < n_diag_blocks - 1: # pass next diagonal block h2d_stream.wait_event(compute_current_B_events[(i + 1) % 2]) - RangePush(f"pobtas: streaming diag blocks {i+1}") L_diagonal_blocks_d[(i + 1) % 2].set( arr=L_diagonal_blocks[i + 1], stream=h2d_stream ) h2d_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) - RangePop() with compute_stream: @@ -376,14 +373,12 @@ def _pobtas_streaming( if i + 1 < n_diag_blocks - 1: # Pass next lower diagonal block h2d_stream.wait_event(compute_next_B_events[(i + 1) % 2]) - RangePush(f"pobtas: streaming lower diag blocks {i+1}") L_lower_diagonal_blocks_d[(i + 1) % 2].set( arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream ) h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) - RangePop() d2h_stream.wait_event(compute_current_B_events[i % 2]) d2h_stream.wait_event(h2d_lower_diagonal_events[(i+1) % 2]) @@ -410,14 +405,12 @@ def _pobtas_streaming( if i + 1 < n_diag_blocks - 1: # Pass next lower arrow block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) - RangePush(f"pobtas: streaming lower arrow blocks{i}") L_lower_arrow_blocks_d[(i + 1) % 2].set( arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream ) 
h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) - RangePop() with compute_stream: # Update arrow tip From c259ccaa895171bd88839d68ba88cd0e183376d4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:44:28 +0000 Subject: [PATCH 236/518] removed test testing --- src/serinv/algs/pobtas.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 9225bd7b..1575a4ac 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -48,9 +48,6 @@ def pobtas( else: # Natural arrowhead if device_streaming: - raise NotImplementedError( - "Test testing." - ) _pobtas_streaming( L_diagonal_blocks, L_lower_diagonal_blocks, From 4cf986f18000a910a23782f1f8a09f1e71d056b1 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:48:25 +0000 Subject: [PATCH 237/518] expanded tests --- tests/tests_algs/regular/tests_bt/test_pobts.py | 9 +++++++++ tests/tests_algs/regular/tests_bta/test_pobtas.py | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index f5c941dc..0f9835a6 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -3,11 +3,20 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize, rhs from serinv.algs import pobtf, pobts +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + if backend_flags["cupy_avail"]: import cupyx as cpx diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index d58e3a19..a94040f0 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -21,9 
+21,9 @@ import cupyx as cpx -@pytest.fixture(params=ARRAY_TYPE, autouse=True) -def array_type(request: pytest.FixtureRequest) -> str: - return request.param +#@pytest.fixture(params=ARRAY_TYPE, autouse=True) +#def array_type(request: pytest.FixtureRequest) -> str: +# return request.param @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) From 632b74ccdb3efcef5fee12924179e8f575c2c12f Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:49:51 +0000 Subject: [PATCH 238/518] expanded tests further --- tests/tests_algs/regular/tests_bt/test_pobts.py | 5 +++++ tests/tests_algs/regular/tests_bta/test_pobtas.py | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index 0f9835a6..f474c6b4 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -20,6 +20,11 @@ if backend_flags["cupy_avail"]: import cupyx as cpx + +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) def test_pobts( diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index a94040f0..d58e3a19 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -21,9 +21,9 @@ import cupyx as cpx -#@pytest.fixture(params=ARRAY_TYPE, autouse=True) -#def array_type(request: pytest.FixtureRequest) -> str: -# return request.param +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) From 680d8990654eb737956a92b619eb98e6967944b2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 07:51:20 +0000 
Subject: [PATCH 239/518] activated streaming tests for pobtaf --- tests/tests_algs/regular/tests_bta/test_pobtaf.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index a30b9094..98756357 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -3,15 +3,28 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize from serinv.algs import pobtaf +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + if backend_flags["cupy_avail"]: import cupyx as cpx +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtaf( diagonal_blocksize: int, From 10de2c5f37536c12cd5f96956b0af0ae061eef4a Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 08:04:03 +0000 Subject: [PATCH 240/518] removed leftover cscs scripts --- run_streamlined_sequential_pobtax_gpu.sh | 51 ---- ...d_sequential_pobtax_gpu.sh:Zone.Identifier | 3 - .../streamlined_sequential_pobtax_gpu.py | 234 ------------------ ...d_sequential_pobtax_gpu.py:Zone.Identifier | 3 - 4 files changed, 291 deletions(-) delete mode 100644 run_streamlined_sequential_pobtax_gpu.sh delete mode 100644 sc25_runs/positive_definite/run_streamlined_sequential_pobtax_gpu.sh:Zone.Identifier delete mode 100644 sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py delete mode 100644 sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py:Zone.Identifier diff --git a/run_streamlined_sequential_pobtax_gpu.sh b/run_streamlined_sequential_pobtax_gpu.sh deleted file mode 100644 index 74b5d9db..00000000 --- 
a/run_streamlined_sequential_pobtax_gpu.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -l -#SBATCH --job-name="serinv_pobtx_benchmark" -#SBATCH --output=%x.%j.out -#SBATCH --error=%x.%j.err -#SBATCH --account=lp16 -#SBATCH --time=00:10:00 -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=64 -#SBATCH --gpus-per-task=1 -#SBATCH --partition=debug -#SBATCH --constraint=gpu -#SBATCH --hint=nomultithread -#SBATCH --uenv=prgenv-gnu/24.11:v1 -#SBATCH --view=modules - -set -e -u - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export MPICH_GPU_SUPPORT_ENABLED=1 -export OMP_PLACES=cores -export OMP_PROC_BIND=close - -export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID - -# source ~/load_modules.sh -conda activate serinv_env - -# Dataset 1: b = 1675, a = 6, n = 128 -# Reference timings (to beat!): -# - pobtaf: 0.38959 -# - pobtas: 0.02415 -# - pobtasi: 0.29593 -# export b=1675 -# export a=6 -# export n=128 - -# Dataset 2: b = 4002, a = 6, n = 250 -# Reference timings (to beat!): -# - pobtaf: 3.2716 (INLA_BTA CUDA code: 2.713) -# - pobtas: 0.15397 -# - pobtasi: 5.15729 -export b=4002 -export a=6 -export n=250 - -# Benchmark the code -srun python ~/serinv/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py --b $b --a $a --n $n - -# Profile the code -# srun nsys profile --force-overwrite=true -o profile_serinv_pobtax_b${b}_a${a}_n${n} python ~/repositories/serinv/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py --b $b --a $a --n $n --b $b --a $a --n $n \ No newline at end of file diff --git a/sc25_runs/positive_definite/run_streamlined_sequential_pobtax_gpu.sh:Zone.Identifier b/sc25_runs/positive_definite/run_streamlined_sequential_pobtax_gpu.sh:Zone.Identifier deleted file mode 100644 index 33e02d64..00000000 --- a/sc25_runs/positive_definite/run_streamlined_sequential_pobtax_gpu.sh:Zone.Identifier +++ /dev/null @@ -1,3 +0,0 @@ -[ZoneTransfer] -ZoneId=3 
-HostUrl=https://iis-mattermost.ee.ethz.ch/api/v4/files/waiggpk1miyeb84dcahdh53b1e?download=1 diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py deleted file mode 100644 index ddb479a2..00000000 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py +++ /dev/null @@ -1,234 +0,0 @@ -import time - -tic = time.perf_counter() -import argparse - -import numpy as np -import cupy as cp -from cupy.cuda.nvtx import RangePush, RangePop -import cupyx as cpx - -from serinv.algs import pobtaf, pobtas, pobtasi - - -def sequential_dataset( - n_blocks: int, - diagonal_blocksize: int, - arrowhead_blocksize: int, -): - A_diagonal_blocks = np.random.rand(n_blocks, diagonal_blocksize, diagonal_blocksize) - A_lower_diagonal_blocks = np.random.rand( - n_blocks - 1, diagonal_blocksize, diagonal_blocksize - ) - A_arrow_bottom_blocks = np.random.rand( - n_blocks, arrowhead_blocksize, diagonal_blocksize - ) - A_arrow_tip_block = np.random.rand(arrowhead_blocksize, arrowhead_blocksize) - - # CODE TO MODIFY - arrow_colsum = np.zeros((arrowhead_blocksize), dtype=A_diagonal_blocks.dtype) - for i in range(A_diagonal_blocks.shape[0]): - colsum = np.sum(A_diagonal_blocks[i, :, :], axis=1) - np.diag( - A_diagonal_blocks[i, :, :] - ) - if i > 0: - colsum += np.sum(A_lower_diagonal_blocks[i - 1, :, :], axis=1) - - A_diagonal_blocks[i, :, :] += np.diag(colsum) - - arrow_colsum[:] += np.sum(A_arrow_bottom_blocks[i, :, :], axis=1) - - A_arrow_tip_block[:, :] += np.diag( - arrow_colsum + np.sum(A_arrow_tip_block[:, :], axis=1) - ) - - return ( - A_diagonal_blocks, - A_lower_diagonal_blocks, - A_arrow_bottom_blocks, - A_arrow_tip_block, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Process some integers.") - parser.add_argument( - "--b", - type=int, - default=128, - help="an integer for the diagonal block size", - ) - parser.add_argument( - "--a", - type=int, 
- default=0, - help="an integer for the diagonal block size", - ) - parser.add_argument( - "--n", - type=int, - default=8, - help="an integer for the number of diagonal blocks", - ) - args = parser.parse_args() - toc = time.perf_counter() - print(f"Import and parsing took: {toc - tic:.5f} sec", flush=True) - - diagonal_blocksize = args.b - arrowhead_blocksize = args.a - n_blocks = args.n - n_iterations = 10 - n_warmups = 2 - - tic = time.perf_counter() - ( - A_diagonal_blocks_cpu, - A_lower_diagonal_blocks_cpu, - A_arrow_bottom_blocks_cpu, - A_arrow_tip_block_cpu, - ) = sequential_dataset( - n_blocks, - diagonal_blocksize, - arrowhead_blocksize, - ) - B_cpu = np.random.rand(diagonal_blocksize * n_blocks + arrowhead_blocksize, 1) - toc = time.perf_counter() - print(f"Generate dataset took: {toc - tic:.5f} sec", flush=True) - print(f" b = {diagonal_blocksize}", flush=True) - print(f" a = {arrowhead_blocksize}", flush=True) - print(f" n = {n_blocks}", flush=True) - print(f" n_iterations = {n_iterations}", flush=True) - print(f" n_warmups = {n_warmups}", flush=True) - - total_memory = ( - A_diagonal_blocks_cpu.nbytes - + A_lower_diagonal_blocks_cpu.nbytes - + A_arrow_bottom_blocks_cpu.nbytes - + A_arrow_tip_block_cpu.nbytes - + B_cpu.nbytes - ) - print(f" Total memory: {total_memory / 1e9:.5f} GB", flush=True) - - tic = time.perf_counter() - # Init device arrays - A_diagonal_blocks_gpu = cp.empty_like(A_diagonal_blocks_cpu) - A_lower_diagonal_blocks_gpu = cp.empty_like(A_lower_diagonal_blocks_cpu) - A_arrow_bottom_blocks_gpu = cp.empty_like(A_arrow_bottom_blocks_cpu) - A_arrow_tip_block_gpu = cp.empty_like(A_arrow_tip_block_cpu) - B_gpu = cp.empty_like(B_cpu) - toc = time.perf_counter() - print(f"Init device arrays took: {toc - tic:.5f} sec", flush=True) - - A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks_cpu) - A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks_cpu[:, :, :] - A_lower_diagonal_blocks_pinned = 
cpx.zeros_like_pinned(A_lower_diagonal_blocks_cpu) - A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks_cpu[:, :, :] - A_lower_arrow_blocks_pinned = cpx.zeros_like_pinned(A_arrow_bottom_blocks_cpu) - A_lower_arrow_blocks_pinned[:, :, :] = A_arrow_bottom_blocks_cpu[:, :, :] - A_arrow_tip_block_pinned = cpx.zeros_like_pinned(A_arrow_tip_block_cpu) - A_arrow_tip_block_pinned[:, :] = A_arrow_tip_block_cpu[:, :] - B_pinned = cpx.zeros_like_pinned(B_cpu) - B_pinned[:, :] = B_cpu[:, :] - - A_diagonal_blocks_cpu = A_diagonal_blocks_pinned - A_lower_diagonal_blocks_cpu = A_lower_diagonal_blocks_pinned - A_arrow_bottom_blocks_cpu = A_lower_arrow_blocks_pinned - A_arrow_tip_block_cpu = A_arrow_tip_block_pinned - B_cpu = B_pinned - - - t_pobtaf = [] - t_pobtas = [] - t_pobtasi = [] - - for i in range(n_warmups + n_iterations): - print(f"Iteration: {i+1}/{n_warmups+n_iterations}", flush=True) - - #tic = time.perf_counter() - #A_diagonal_blocks_gpu.set(arr=A_diagonal_blocks_cpu) - #A_lower_diagonal_blocks_gpu.set(arr=A_lower_diagonal_blocks_cpu) - #A_arrow_bottom_blocks_gpu.set(arr=A_arrow_bottom_blocks_cpu) - #A_arrow_tip_block_gpu.set(arr=A_arrow_tip_block_cpu) - #B_gpu.set(arr=B_cpu) - #toc = time.perf_counter() - #print(f"Copying data to GPU took: {toc - tic:.5f} sec", flush=True) - - cp.cuda.runtime.deviceSynchronize() - RangePush(f"pobtaf: i:{i}") - tic = time.perf_counter() - pobtaf( - A_diagonal_blocks_cpu, - A_lower_diagonal_blocks_cpu, - A_arrow_bottom_blocks_cpu, - A_arrow_tip_block_cpu, - device_streaming=True - ) - cp.cuda.runtime.deviceSynchronize() - toc = time.perf_counter() - RangePop() - elapsed = toc - tic - print(f"pobtaf took: {elapsed:.5f} sec", flush=True) - if i >= n_warmups: - t_pobtaf.append(elapsed) - - #tic = time.perf_counter() - #A_diagonal_blocks_gpu.get(out=A_diagonal_blocks_cpu) - #A_lower_diagonal_blocks_gpu.get(out=A_lower_diagonal_blocks_cpu) - #A_arrow_bottom_blocks_gpu.get(out=A_arrow_bottom_blocks_cpu) - 
#A_arrow_tip_block_gpu.get(out=A_arrow_tip_block_cpu) - #B_gpu.get(out=B_cpu) - #toc = time.perf_counter() - #print(f"Copying data from GPU took: {toc - tic:.5f} sec", flush=True) - - cp.cuda.runtime.deviceSynchronize() - # RangePush(f"pobtas: i:{i}") - tic = time.perf_counter() - pobtas( - A_diagonal_blocks_cpu, - A_lower_diagonal_blocks_cpu, - A_arrow_bottom_blocks_cpu, - A_arrow_tip_block_cpu, - B_cpu, - device_streaming=True - ) - cp.cuda.runtime.deviceSynchronize() - toc = time.perf_counter() - # RangePop() - elapsed = toc - tic - print(f"pobtas took: {elapsed:.5f} sec", flush=True) - if i >= n_warmups: - t_pobtas.append(elapsed) - - tic = time.perf_counter() - A_diagonal_blocks_gpu.set(arr=A_diagonal_blocks_cpu) - A_lower_diagonal_blocks_gpu.set(arr=A_lower_diagonal_blocks_cpu) - A_arrow_bottom_blocks_gpu.set(arr=A_arrow_bottom_blocks_cpu) - A_arrow_tip_block_gpu.set(arr=A_arrow_tip_block_cpu) - B_gpu.set(arr=B_cpu) - toc = time.perf_counter() - print(f"Copying data to GPU took: {toc - tic:.5f} sec", flush=True) - - cp.cuda.runtime.deviceSynchronize() - RangePush(f"pobtasi: i:{i}") - tic = time.perf_counter() - pobtasi( - A_diagonal_blocks_gpu, - A_lower_diagonal_blocks_gpu, - A_arrow_bottom_blocks_gpu, - A_arrow_tip_block_gpu, - ) - cp.cuda.runtime.deviceSynchronize() - toc = time.perf_counter() - RangePop() - elapsed = toc - tic - print(f"pobtasi took: {elapsed:.5f} sec", flush=True) - if i >= n_warmups: - t_pobtasi.append(elapsed) - - print(f"t_pobtaf: {t_pobtaf}", flush=True) - print(f"t_pobtas: {t_pobtas}", flush=True) - print(f"t_pobtasi: {t_pobtasi}", flush=True) - - print(f"avg t_pobtaf: {np.mean(np.array(t_pobtaf)):.5f} sec", flush=True) - print(f"avg t_pobtas: {np.mean(np.array(t_pobtas)):.5f} sec", flush=True) - print(f"avg t_pobtasi: {np.mean(np.array(t_pobtasi)):.5f} sec", flush=True) \ No newline at end of file diff --git a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py:Zone.Identifier 
b/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py:Zone.Identifier deleted file mode 100644 index ce8dec59..00000000 --- a/sc25_runs/positive_definite/streamlined_sequential_pobtax_gpu.py:Zone.Identifier +++ /dev/null @@ -1,3 +0,0 @@ -[ZoneTransfer] -ZoneId=3 -HostUrl=https://iis-mattermost.ee.ethz.ch/api/v4/files/fw5m5tapefbi8deseto5qqro9w?download=1 From 7390a6b56c8f59efab0850366f331ad239501350 Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 16 May 2025 08:22:14 +0000 Subject: [PATCH 241/518] removed line that forced streaming --- tests/tests_algs/regular/tests_bt/test_pobts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index f474c6b4..d137c796 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -34,7 +34,6 @@ def test_pobts( array_type: str, dtype: np.dtype, ): - array_type = "streaming" A = dd_bt( diagonal_blocksize, From c8f89e2e351fc7519e284051bd41b590eba4f6d1 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 27 May 2025 14:46:12 +0000 Subject: [PATCH 242/518] first modification to get cupy and scipy implementations for trsm right and left hand side --- src/serinv/utils/trsm_solve_device.py | 103 ++++++++++++++++++++++ src/serinv/utils/trsm_solve_host.py | 122 ++++++++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 src/serinv/utils/trsm_solve_device.py create mode 100644 src/serinv/utils/trsm_solve_host.py diff --git a/src/serinv/utils/trsm_solve_device.py b/src/serinv/utils/trsm_solve_device.py new file mode 100644 index 00000000..9d358a1c --- /dev/null +++ b/src/serinv/utils/trsm_solve_device.py @@ -0,0 +1,103 @@ +import numpy + +from cupy.cuda import cublas +from cupy.cuda import device +from cupy.linalg import _util + + +def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=False, aplha = 
1., side=0): + """Solve the equation a x = b for x, assuming a is a triangular matrix. + + Args: + a (cupy.ndarray): The matrix with dimension ``(M, M)``. + b (cupy.ndarray): The matrix with dimension ``(M,)`` or + ``(M, N)``. + lower (bool): Use only data contained in the lower triangle of ``a``. + Default is to use upper triangle. + trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: + + - *'0'* or *'N'* -- :math:`a x = b` + - *'1'* or *'T'* -- :math:`a^T x = b` + - *'2'* or *'C'* -- :math:`a^H x = b` + + unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are + assumed to be 1 and will not be referenced. + overwrite_b (bool): Allow overwriting data in b (may enhance + performance) + check_finite (bool): Whether to check that the input matrices contain + only finite numbers. Disabling may give a performance gain, but may + result in problems (crashes, non-termination) if the inputs do + contain infinities or NaNs. + + Returns: + cupy.ndarray: + The matrix with dimension ``(M,)`` or ``(M, N)``. + + .. 
seealso:: :func:`scipy.linalg.solve_triangular` + """ + + _util._assert_cupy_array(a, b) + + if len(a.shape) != 2 or a.shape[0] != a.shape[1]: + raise ValueError('expected square matrix') + if len(a) != len(b): + raise ValueError('incompatible dimensions') + + # Cast to float32 or float64 + if a.dtype.char in 'fd': + dtype = a.dtype + else: + dtype = numpy.promote_types(a.dtype.char, 'f') + + a = cupy.array(a, dtype=dtype, order='F', copy=False) + b = cupy.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) + + if check_finite: + if a.dtype.kind == 'f' and not cupy.isfinite(a).all(): + raise ValueError( + 'array must not contain infs or NaNs') + if b.dtype.kind == 'f' and not cupy.isfinite(b).all(): + raise ValueError( + 'array must not contain infs or NaNs') + + m, n = (b.size, 1) if b.ndim == 1 else b.shape + cublas_handle = device.get_cublas_handle() + + if dtype == 'f': + trsm = cublas.strsm + elif dtype == 'd': + trsm = cublas.dtrsm + elif dtype == 'F': + trsm = cublas.ctrsm + else: # dtype == 'D' + trsm = cublas.ztrsm + one = numpy.array(1, dtype=dtype) + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + if trans == 'N': + trans = cublas.CUBLAS_OP_N + elif trans == 'T': + trans = cublas.CUBLAS_OP_T + elif trans == 'C': + trans = cublas.CUBLAS_OP_C + + if unit_diagonal: + diag = cublas.CUBLAS_DIAG_UNIT + else: + diag = cublas.CUBLAS_DIAG_NON_UNIT + + if side: + blas_side = cublas.CUBLAS_SIDE_RIGHT + else: + blas_side = cublas.CUBLAS_SIDE_LEFT + + trsm( + cublas_handle, blas_side, uplo, + trans, diag, + m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) + return b \ No newline at end of file diff --git a/src/serinv/utils/trsm_solve_host.py b/src/serinv/utils/trsm_solve_host.py new file mode 100644 index 00000000..820770e4 --- /dev/null +++ b/src/serinv/utils/trsm_solve_host.py @@ -0,0 +1,122 @@ +import numpy as np + + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import 
_datacopied +from scipy.linalg._decomp import _asarray_validated + +def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=True, side=0): + """ + Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. + + Parameters + ---------- + a : (M, M) array_like + A triangular matrix + b : (M,) or (M, N) array_like + Right-hand side matrix in ``a x = b`` + lower : bool, optional + Use only data contained in the lower triangle of `a`. + Default is to use upper triangle. + trans : {0, 1, 2, 'N', 'T', 'C'}, optional + Type of system to solve: + + ======== ========= + trans system + ======== ========= + 0 or 'N' a x = b + 1 or 'T' a^T x = b + 2 or 'C' a^H x = b + ======== ========= + unit_diagonal : bool, optional + If True, diagonal elements of `a` are assumed to be 1 and + will not be referenced. + overwrite_b : bool, optional + Allow overwriting data in `b` (may enhance performance) + check_finite : bool, optional + Whether to check that the input matrices contain only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Returns + ------- + x : (M,) or (M, N) ndarray + Solution to the system ``a x = b``. Shape of return matches `b`. + + Raises + ------ + LinAlgError + If `a` is singular + + Notes + ----- + .. 
versionadded:: 0.9.0 + + Examples + -------- + Solve the lower triangular system a x = b, where:: + + [3 0 0 0] [4] + a = [2 1 0 0] b = [2] + [1 0 1 0] [4] + [1 1 1 1] [2] + + >>> import numpy as np + >>> from scipy.linalg import solve_triangular + >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) + >>> b = np.array([4, 2, 4, 2]) + >>> x = solve_triangular(a, b, lower=True) + >>> x + array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) + >>> a.dot(x) # Check the result + array([ 4., 2., 4., 2.]) + + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + b1 = _asarray_validated(b, check_finite=check_finite) + + if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: + raise ValueError('expected square matrix') + + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + # accommodate empty arrays + if b1.size == 0: + dt_nonempty = solve_triangular_host( + np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) + ).dtype + return np.empty_like(b1, dtype=dt_nonempty) + + overwrite_b = overwrite_b or _datacopied(b1, b) + + x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + return x + + +# solve_triangular without the input validation +def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + trsm, = get_blas_funcs(('trsm',), (a1, b1)) + + if a1.dtype.char in 'fd': + dtype = a1.dtype + else: + dtype = np.promote_types(a1.dtype.char, 'f') + + one = np.array(1, dtype=dtype) + alpha = one.ctypes.data + + if a1.flags.f_contiguous or trans == 2: + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, + trans_a=trans, diag=unit_diagonal, side=side) + else: + # transposed system is solved since trtrs expects Fortran ordering + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + trans_a=not trans, diag=unit_diagonal, side=side) + + 
return x From cbd07cdb59657d040939cbf6ff6d908bdab2ca09 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 27 May 2025 14:51:09 +0000 Subject: [PATCH 243/518] moved improvement files to new branch --- src/serinv/utils/trsm_solve_device.py | 103 ---------------------- src/serinv/utils/trsm_solve_host.py | 122 -------------------------- 2 files changed, 225 deletions(-) delete mode 100644 src/serinv/utils/trsm_solve_device.py delete mode 100644 src/serinv/utils/trsm_solve_host.py diff --git a/src/serinv/utils/trsm_solve_device.py b/src/serinv/utils/trsm_solve_device.py deleted file mode 100644 index 9d358a1c..00000000 --- a/src/serinv/utils/trsm_solve_device.py +++ /dev/null @@ -1,103 +0,0 @@ -import numpy - -from cupy.cuda import cublas -from cupy.cuda import device -from cupy.linalg import _util - - -def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, check_finite=False, aplha = 1., side=0): - """Solve the equation a x = b for x, assuming a is a triangular matrix. - - Args: - a (cupy.ndarray): The matrix with dimension ``(M, M)``. - b (cupy.ndarray): The matrix with dimension ``(M,)`` or - ``(M, N)``. - lower (bool): Use only data contained in the lower triangle of ``a``. - Default is to use upper triangle. - trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: - - - *'0'* or *'N'* -- :math:`a x = b` - - *'1'* or *'T'* -- :math:`a^T x = b` - - *'2'* or *'C'* -- :math:`a^H x = b` - - unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are - assumed to be 1 and will not be referenced. - overwrite_b (bool): Allow overwriting data in b (may enhance - performance) - check_finite (bool): Whether to check that the input matrices contain - only finite numbers. Disabling may give a performance gain, but may - result in problems (crashes, non-termination) if the inputs do - contain infinities or NaNs. - - Returns: - cupy.ndarray: - The matrix with dimension ``(M,)`` or ``(M, N)``. - - .. 
seealso:: :func:`scipy.linalg.solve_triangular` - """ - - _util._assert_cupy_array(a, b) - - if len(a.shape) != 2 or a.shape[0] != a.shape[1]: - raise ValueError('expected square matrix') - if len(a) != len(b): - raise ValueError('incompatible dimensions') - - # Cast to float32 or float64 - if a.dtype.char in 'fd': - dtype = a.dtype - else: - dtype = numpy.promote_types(a.dtype.char, 'f') - - a = cupy.array(a, dtype=dtype, order='F', copy=False) - b = cupy.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) - - if check_finite: - if a.dtype.kind == 'f' and not cupy.isfinite(a).all(): - raise ValueError( - 'array must not contain infs or NaNs') - if b.dtype.kind == 'f' and not cupy.isfinite(b).all(): - raise ValueError( - 'array must not contain infs or NaNs') - - m, n = (b.size, 1) if b.ndim == 1 else b.shape - cublas_handle = device.get_cublas_handle() - - if dtype == 'f': - trsm = cublas.strsm - elif dtype == 'd': - trsm = cublas.dtrsm - elif dtype == 'F': - trsm = cublas.ctrsm - else: # dtype == 'D' - trsm = cublas.ztrsm - one = numpy.array(1, dtype=dtype) - - if lower: - uplo = cublas.CUBLAS_FILL_MODE_LOWER - else: - uplo = cublas.CUBLAS_FILL_MODE_UPPER - - if trans == 'N': - trans = cublas.CUBLAS_OP_N - elif trans == 'T': - trans = cublas.CUBLAS_OP_T - elif trans == 'C': - trans = cublas.CUBLAS_OP_C - - if unit_diagonal: - diag = cublas.CUBLAS_DIAG_UNIT - else: - diag = cublas.CUBLAS_DIAG_NON_UNIT - - if side: - blas_side = cublas.CUBLAS_SIDE_RIGHT - else: - blas_side = cublas.CUBLAS_SIDE_LEFT - - trsm( - cublas_handle, blas_side, uplo, - trans, diag, - m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) - return b \ No newline at end of file diff --git a/src/serinv/utils/trsm_solve_host.py b/src/serinv/utils/trsm_solve_host.py deleted file mode 100644 index 820770e4..00000000 --- a/src/serinv/utils/trsm_solve_host.py +++ /dev/null @@ -1,122 +0,0 @@ -import numpy as np - - -from scipy.linalg.blas import get_blas_funcs -from scipy.linalg._misc import 
_datacopied -from scipy.linalg._decomp import _asarray_validated - -def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, check_finite=True, side=0): - """ - Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. - - Parameters - ---------- - a : (M, M) array_like - A triangular matrix - b : (M,) or (M, N) array_like - Right-hand side matrix in ``a x = b`` - lower : bool, optional - Use only data contained in the lower triangle of `a`. - Default is to use upper triangle. - trans : {0, 1, 2, 'N', 'T', 'C'}, optional - Type of system to solve: - - ======== ========= - trans system - ======== ========= - 0 or 'N' a x = b - 1 or 'T' a^T x = b - 2 or 'C' a^H x = b - ======== ========= - unit_diagonal : bool, optional - If True, diagonal elements of `a` are assumed to be 1 and - will not be referenced. - overwrite_b : bool, optional - Allow overwriting data in `b` (may enhance performance) - check_finite : bool, optional - Whether to check that the input matrices contain only finite numbers. - Disabling may give a performance gain, but may result in problems - (crashes, non-termination) if the inputs do contain infinities or NaNs. - - Returns - ------- - x : (M,) or (M, N) ndarray - Solution to the system ``a x = b``. Shape of return matches `b`. - - Raises - ------ - LinAlgError - If `a` is singular - - Notes - ----- - .. 
versionadded:: 0.9.0 - - Examples - -------- - Solve the lower triangular system a x = b, where:: - - [3 0 0 0] [4] - a = [2 1 0 0] b = [2] - [1 0 1 0] [4] - [1 1 1 1] [2] - - >>> import numpy as np - >>> from scipy.linalg import solve_triangular - >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) - >>> b = np.array([4, 2, 4, 2]) - >>> x = solve_triangular(a, b, lower=True) - >>> x - array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) - >>> a.dot(x) # Check the result - array([ 4., 2., 4., 2.]) - - """ - - a1 = _asarray_validated(a, check_finite=check_finite) - b1 = _asarray_validated(b, check_finite=check_finite) - - if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: - raise ValueError('expected square matrix') - - if a1.shape[0] != b1.shape[0]: - raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') - - # accommodate empty arrays - if b1.size == 0: - dt_nonempty = solve_triangular_host( - np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) - ).dtype - return np.empty_like(b1, dtype=dt_nonempty) - - overwrite_b = overwrite_b or _datacopied(b1, b) - - x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) - return x - - -# solve_triangular without the input validation -def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, side=0): - - trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) - trsm, = get_blas_funcs(('trsm',), (a1, b1)) - - if a1.dtype.char in 'fd': - dtype = a1.dtype - else: - dtype = np.promote_types(a1.dtype.char, 'f') - - one = np.array(1, dtype=dtype) - alpha = one.ctypes.data - - if a1.flags.f_contiguous or trans == 2: - x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, - trans_a=trans, diag=unit_diagonal, side=side) - else: - # transposed system is solved since trtrs expects Fortran ordering - x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, - trans_a=not trans, diag=unit_diagonal, side=side) - - 
return x From 9dca6a88c0a9c8ff684b0f0935654d6b8a4d6481 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 27 May 2025 14:51:49 +0000 Subject: [PATCH 244/518] first implementation of trsm from cupy and scipy for left and right hand side --- src/serinv/utils/trsm_solve_device.py | 102 +++++++++++++++++++++ src/serinv/utils/trsm_solve_host.py | 122 ++++++++++++++++++++++++++ 2 files changed, 224 insertions(+) create mode 100644 src/serinv/utils/trsm_solve_device.py create mode 100644 src/serinv/utils/trsm_solve_host.py diff --git a/src/serinv/utils/trsm_solve_device.py b/src/serinv/utils/trsm_solve_device.py new file mode 100644 index 00000000..9d6bdb43 --- /dev/null +++ b/src/serinv/utils/trsm_solve_device.py @@ -0,0 +1,102 @@ +import numpy + +from cupy.cuda import cublas +from cupy.cuda import device +from cupy.linalg import _util + + +def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=False, aplha = 1., side=0): + """Solve the equation a x = b for x, assuming a is a triangular matrix. + + Args: + a (cupy.ndarray): The matrix with dimension ``(M, M)``. + b (cupy.ndarray): The matrix with dimension ``(M,)`` or + ``(M, N)``. + lower (bool): Use only data contained in the lower triangle of ``a``. + Default is to use upper triangle. + trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: + + - *'0'* or *'N'* -- :math:`a x = b` + - *'1'* or *'T'* -- :math:`a^T x = b` + - *'2'* or *'C'* -- :math:`a^H x = b` + + unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are + assumed to be 1 and will not be referenced. + overwrite_b (bool): Allow overwriting data in b (may enhance + performance) + check_finite (bool): Whether to check that the input matrices contain + only finite numbers. Disabling may give a performance gain, but may + result in problems (crashes, non-termination) if the inputs do + contain infinities or NaNs. + + Returns: + cupy.ndarray: + The matrix with dimension ``(M,)`` or ``(M, N)``. 
+ + .. seealso:: :func:`scipy.linalg.solve_triangular` + """ + + _util._assert_cupy_array(a, b) + + if len(a.shape) != 2 or a.shape[0] != a.shape[1]: + raise ValueError('expected square matrix') + if len(a) != len(b): + raise ValueError('incompatible dimensions') + + # Cast to float32 or float64 + if a.dtype.char in 'fd': + dtype = a.dtype + else: + dtype = numpy.promote_types(a.dtype.char, 'f') + + a = cupy.array(a, dtype=dtype, order='F', copy=False) + b = cupy.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) + + if check_finite: + if a.dtype.kind == 'f' and not cupy.isfinite(a).all(): + raise ValueError( + 'array must not contain infs or NaNs') + if b.dtype.kind == 'f' and not cupy.isfinite(b).all(): + raise ValueError( + 'array must not contain infs or NaNs') + + m, n = (b.size, 1) if b.ndim == 1 else b.shape + cublas_handle = device.get_cublas_handle() + + if dtype == 'f': + trsm = cublas.strsm + elif dtype == 'd': + trsm = cublas.dtrsm + elif dtype == 'F': + trsm = cublas.ctrsm + else: # dtype == 'D' + trsm = cublas.ztrsm + one = numpy.array(1, dtype=dtype) + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + if trans == 'N': + trans = cublas.CUBLAS_OP_N + elif trans == 'T': + trans = cublas.CUBLAS_OP_T + elif trans == 'C': + trans = cublas.CUBLAS_OP_C + + if unit_diagonal: + diag = cublas.CUBLAS_DIAG_UNIT + else: + diag = cublas.CUBLAS_DIAG_NON_UNIT + + if side: + blas_side = cublas.CUBLAS_SIDE_RIGHT + else: + blas_side = cublas.CUBLAS_SIDE_LEFT + + trsm( + cublas_handle, blas_side, uplo, + trans, diag, + m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) \ No newline at end of file diff --git a/src/serinv/utils/trsm_solve_host.py b/src/serinv/utils/trsm_solve_host.py new file mode 100644 index 00000000..820770e4 --- /dev/null +++ b/src/serinv/utils/trsm_solve_host.py @@ -0,0 +1,122 @@ +import numpy as np + + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied 
+from scipy.linalg._decomp import _asarray_validated + +def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=True, side=0): + """ + Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. + + Parameters + ---------- + a : (M, M) array_like + A triangular matrix + b : (M,) or (M, N) array_like + Right-hand side matrix in ``a x = b`` + lower : bool, optional + Use only data contained in the lower triangle of `a`. + Default is to use upper triangle. + trans : {0, 1, 2, 'N', 'T', 'C'}, optional + Type of system to solve: + + ======== ========= + trans system + ======== ========= + 0 or 'N' a x = b + 1 or 'T' a^T x = b + 2 or 'C' a^H x = b + ======== ========= + unit_diagonal : bool, optional + If True, diagonal elements of `a` are assumed to be 1 and + will not be referenced. + overwrite_b : bool, optional + Allow overwriting data in `b` (may enhance performance) + check_finite : bool, optional + Whether to check that the input matrices contain only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Returns + ------- + x : (M,) or (M, N) ndarray + Solution to the system ``a x = b``. Shape of return matches `b`. + + Raises + ------ + LinAlgError + If `a` is singular + + Notes + ----- + .. 
versionadded:: 0.9.0 + + Examples + -------- + Solve the lower triangular system a x = b, where:: + + [3 0 0 0] [4] + a = [2 1 0 0] b = [2] + [1 0 1 0] [4] + [1 1 1 1] [2] + + >>> import numpy as np + >>> from scipy.linalg import solve_triangular + >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) + >>> b = np.array([4, 2, 4, 2]) + >>> x = solve_triangular(a, b, lower=True) + >>> x + array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) + >>> a.dot(x) # Check the result + array([ 4., 2., 4., 2.]) + + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + b1 = _asarray_validated(b, check_finite=check_finite) + + if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: + raise ValueError('expected square matrix') + + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + # accommodate empty arrays + if b1.size == 0: + dt_nonempty = solve_triangular_host( + np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) + ).dtype + return np.empty_like(b1, dtype=dt_nonempty) + + overwrite_b = overwrite_b or _datacopied(b1, b) + + x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + return x + + +# solve_triangular without the input validation +def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + trsm, = get_blas_funcs(('trsm',), (a1, b1)) + + if a1.dtype.char in 'fd': + dtype = a1.dtype + else: + dtype = np.promote_types(a1.dtype.char, 'f') + + one = np.array(1, dtype=dtype) + alpha = one.ctypes.data + + if a1.flags.f_contiguous or trans == 2: + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, + trans_a=trans, diag=unit_diagonal, side=side) + else: + # transposed system is solved since trtrs expects Fortran ordering + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + trans_a=not trans, diag=unit_diagonal, side=side) + + 
return x From 7e3884e7a03cab3653f2d5f94f27cc0b79bd53bb Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:03:11 +0000 Subject: [PATCH 245/518] matmul implementation --- .../utils/{trsm_solve_host.py => matmul.py} | 51 ++-- src/serinv/utils/trsm.py | 246 ++++++++++++++++++ src/serinv/utils/trsm_solve_device.py | 102 -------- 3 files changed, 279 insertions(+), 120 deletions(-) rename src/serinv/utils/{trsm_solve_host.py => matmul.py} (73%) create mode 100644 src/serinv/utils/trsm.py delete mode 100644 src/serinv/utils/trsm_solve_device.py diff --git a/src/serinv/utils/trsm_solve_host.py b/src/serinv/utils/matmul.py similarity index 73% rename from src/serinv/utils/trsm_solve_host.py rename to src/serinv/utils/matmul.py index 820770e4..aa68ab31 100644 --- a/src/serinv/utils/trsm_solve_host.py +++ b/src/serinv/utils/matmul.py @@ -1,12 +1,31 @@ -import numpy as np +from serinv import _get_module_from_array +import numpy as np +from numpy.linalg import matmul from scipy.linalg.blas import get_blas_funcs from scipy.linalg._misc import _datacopied from scipy.linalg._decomp import _asarray_validated -def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, check_finite=True, side=0): +try: + import cupy as cp + from cupy.cublas import gemm +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def serinv_matmul (a, b): + """Wrapper to call GeMM for host or device""" + xp, la = _get_module_from_array(a) + + if xp == np: + return matmul(a, b) + elif xp == cp: + return gemm('N', 'N', a, b) + else: + ModuleNotFoundError("Unknown Module") + + +def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): """ Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. 
@@ -85,23 +104,21 @@ def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, # accommodate empty arrays if b1.size == 0: - dt_nonempty = solve_triangular_host( + dt_nonempty = matmul_gemm_host( np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - overwrite_b = overwrite_b or _datacopied(b1, b) - - x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + x = _solve_triangular(a1, b1, trans_a, trans_b, overwrite_c) return x # solve_triangular without the input validation -def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, side=0): +def _solve_triangular(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): - trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) - trsm, = get_blas_funcs(('trsm',), (a1, b1)) + trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) + trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) + gemm, = get_blas_funcs(('gemm',), (a1, b1)) if a1.dtype.char in 'fd': dtype = a1.dtype @@ -109,14 +126,12 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, dtype = np.promote_types(a1.dtype.char, 'f') one = np.array(1, dtype=dtype) + zero =np.array(0, dtype=dtype) alpha = one.ctypes.data + beta = zero.ctypes.data - if a1.flags.f_contiguous or trans == 2: - x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, - trans_a=trans, diag=unit_diagonal, side=side) - else: - # transposed system is solved since trtrs expects Fortran ordering - x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, - trans_a=not trans, diag=unit_diagonal, side=side) + + x = gemm(alpha, a1.T, b1.T, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + return x diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py new file mode 100644 index 00000000..159f32a4 --- /dev/null +++ b/src/serinv/utils/trsm.py @@ -0,0 +1,246 @@ +import numpy as np + +from serinv import 
_get_module_from_array + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +try: + import cupy as cp + from cupy.cuda import cublas + from cupy.cuda import device + from cupy.linalg import _util +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, + overwrite_b=False, check_finite=False, side=0): + """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device + + For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept + plus the side parameter which can either be 0 or 1 for left or right hand side + """ + xp = _get_module_from_array(a) + + if xp == np: + return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) + elif xp == cp: + return solve_triangular_device(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) + else: + ModuleNotFoundError("Unknown Module") + + + +def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=False, side=0): + """Solve the equation a x = b for x, assuming a is a triangular matrix. + + Args: + a (cupy.ndarray): The matrix with dimension ``(M, M)``. + b (cupy.ndarray): The matrix with dimension ``(M,)`` or + ``(M, N)``. + lower (bool): Use only data contained in the lower triangle of ``a``. + Default is to use upper triangle. + trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: + + - *'0'* or *'N'* -- :math:`a x = b` + - *'1'* or *'T'* -- :math:`a^T x = b` + - *'2'* or *'C'* -- :math:`a^H x = b` + + unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are + assumed to be 1 and will not be referenced. 
+ overwrite_b (bool): Allow overwriting data in b (may enhance + performance) + check_finite (bool): Whether to check that the input matrices contain + only finite numbers. Disabling may give a performance gain, but may + result in problems (crashes, non-termination) if the inputs do + contain infinities or NaNs. + + Returns: + cupy.ndarray: + The matrix with dimension ``(M,)`` or ``(M, N)``. + + .. seealso:: :func:`scipy.linalg.solve_triangular` + """ + + _util._assert_cupy_array(a, b) + + if len(a.shape) != 2 or a.shape[0] != a.shape[1]: + raise ValueError('expected square matrix') + if len(a) != len(b): + raise ValueError('incompatible dimensions') + + # Cast to float32 or float64 + if a.dtype.char in 'fd': + dtype = a.dtype + else: + dtype = np.promote_types(a.dtype.char, 'f') + + a = cp.array(a, dtype=dtype, order='F', copy=False) + b = cp.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) + + if check_finite: + if a.dtype.kind == 'f' and not cp.isfinite(a).all(): + raise ValueError( + 'array must not contain infs or NaNs') + if b.dtype.kind == 'f' and not cp.isfinite(b).all(): + raise ValueError( + 'array must not contain infs or NaNs') + + m, n = (b.size, 1) if b.ndim == 1 else b.shape + cublas_handle = device.get_cublas_handle() + + if dtype == 'f': + trsm = cublas.strsm + elif dtype == 'd': + trsm = cublas.dtrsm + elif dtype == 'F': + trsm = cublas.ctrsm + else: # dtype == 'D' + trsm = cublas.ztrsm + one = np.array(1, dtype=dtype) + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + if trans == 'N': + trans = cublas.CUBLAS_OP_N + elif trans == 'T': + trans = cublas.CUBLAS_OP_T + elif trans == 'C': + trans = cublas.CUBLAS_OP_C + + if unit_diagonal: + diag = cublas.CUBLAS_DIAG_UNIT + else: + diag = cublas.CUBLAS_DIAG_NON_UNIT + + if side: + blas_side = cublas.CUBLAS_SIDE_RIGHT + else: + blas_side = cublas.CUBLAS_SIDE_LEFT + + trsm( + cublas_handle, blas_side, uplo, + trans, diag, + m, n, 
one.ctypes.data, a.data.ptr, m, b.data.ptr, m) + + +def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=True, side=0): + """ + Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. + + Parameters + ---------- + a : (M, M) array_like + A triangular matrix + b : (M,) or (M, N) array_like + Right-hand side matrix in ``a x = b`` + lower : bool, optional + Use only data contained in the lower triangle of `a`. + Default is to use upper triangle. + trans : {0, 1, 2, 'N', 'T', 'C'}, optional + Type of system to solve: + + ======== ========= + trans system + ======== ========= + 0 or 'N' a x = b + 1 or 'T' a^T x = b + 2 or 'C' a^H x = b + ======== ========= + unit_diagonal : bool, optional + If True, diagonal elements of `a` are assumed to be 1 and + will not be referenced. + overwrite_b : bool, optional + Allow overwriting data in `b` (may enhance performance) + check_finite : bool, optional + Whether to check that the input matrices contain only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Returns + ------- + x : (M,) or (M, N) ndarray + Solution to the system ``a x = b``. Shape of return matches `b`. + + Raises + ------ + LinAlgError + If `a` is singular + + Notes + ----- + .. 
versionadded:: 0.9.0 + + Examples + -------- + Solve the lower triangular system a x = b, where:: + + [3 0 0 0] [4] + a = [2 1 0 0] b = [2] + [1 0 1 0] [4] + [1 1 1 1] [2] + + >>> import numpy as np + >>> from scipy.linalg import solve_triangular + >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) + >>> b = np.array([4, 2, 4, 2]) + >>> x = solve_triangular(a, b, lower=True) + >>> x + array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) + >>> a.dot(x) # Check the result + array([ 4., 2., 4., 2.]) + + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + b1 = _asarray_validated(b, check_finite=check_finite) + + if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: + raise ValueError('expected square matrix') + + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + # accommodate empty arrays + if b1.size == 0: + dt_nonempty = solve_triangular_host( + np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) + ).dtype + return np.empty_like(b1, dtype=dt_nonempty) + + overwrite_b = overwrite_b or _datacopied(b1, b) + + x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + return x + + +# solve_triangular without the input validation +def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + trsm, = get_blas_funcs(('trsm',), (a1, b1)) + + if a1.dtype.char in 'fd': + dtype = a1.dtype + else: + dtype = np.promote_types(a1.dtype.char, 'f') + + one = np.array(1, dtype=dtype) + alpha = one.ctypes.data + + if a1.flags.f_contiguous or trans == 2: + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, + trans_a=trans, diag=unit_diagonal, side=side) + else: + # transposed system is solved since trtrs expects Fortran ordering + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + trans_a=not trans, diag=unit_diagonal, side=side) + + 
return x \ No newline at end of file diff --git a/src/serinv/utils/trsm_solve_device.py b/src/serinv/utils/trsm_solve_device.py deleted file mode 100644 index 9d6bdb43..00000000 --- a/src/serinv/utils/trsm_solve_device.py +++ /dev/null @@ -1,102 +0,0 @@ -import numpy - -from cupy.cuda import cublas -from cupy.cuda import device -from cupy.linalg import _util - - -def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, - overwrite_b=False, check_finite=False, aplha = 1., side=0): - """Solve the equation a x = b for x, assuming a is a triangular matrix. - - Args: - a (cupy.ndarray): The matrix with dimension ``(M, M)``. - b (cupy.ndarray): The matrix with dimension ``(M,)`` or - ``(M, N)``. - lower (bool): Use only data contained in the lower triangle of ``a``. - Default is to use upper triangle. - trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: - - - *'0'* or *'N'* -- :math:`a x = b` - - *'1'* or *'T'* -- :math:`a^T x = b` - - *'2'* or *'C'* -- :math:`a^H x = b` - - unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are - assumed to be 1 and will not be referenced. - overwrite_b (bool): Allow overwriting data in b (may enhance - performance) - check_finite (bool): Whether to check that the input matrices contain - only finite numbers. Disabling may give a performance gain, but may - result in problems (crashes, non-termination) if the inputs do - contain infinities or NaNs. - - Returns: - cupy.ndarray: - The matrix with dimension ``(M,)`` or ``(M, N)``. - - .. 
seealso:: :func:`scipy.linalg.solve_triangular` - """ - - _util._assert_cupy_array(a, b) - - if len(a.shape) != 2 or a.shape[0] != a.shape[1]: - raise ValueError('expected square matrix') - if len(a) != len(b): - raise ValueError('incompatible dimensions') - - # Cast to float32 or float64 - if a.dtype.char in 'fd': - dtype = a.dtype - else: - dtype = numpy.promote_types(a.dtype.char, 'f') - - a = cupy.array(a, dtype=dtype, order='F', copy=False) - b = cupy.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) - - if check_finite: - if a.dtype.kind == 'f' and not cupy.isfinite(a).all(): - raise ValueError( - 'array must not contain infs or NaNs') - if b.dtype.kind == 'f' and not cupy.isfinite(b).all(): - raise ValueError( - 'array must not contain infs or NaNs') - - m, n = (b.size, 1) if b.ndim == 1 else b.shape - cublas_handle = device.get_cublas_handle() - - if dtype == 'f': - trsm = cublas.strsm - elif dtype == 'd': - trsm = cublas.dtrsm - elif dtype == 'F': - trsm = cublas.ctrsm - else: # dtype == 'D' - trsm = cublas.ztrsm - one = numpy.array(1, dtype=dtype) - - if lower: - uplo = cublas.CUBLAS_FILL_MODE_LOWER - else: - uplo = cublas.CUBLAS_FILL_MODE_UPPER - - if trans == 'N': - trans = cublas.CUBLAS_OP_N - elif trans == 'T': - trans = cublas.CUBLAS_OP_T - elif trans == 'C': - trans = cublas.CUBLAS_OP_C - - if unit_diagonal: - diag = cublas.CUBLAS_DIAG_UNIT - else: - diag = cublas.CUBLAS_DIAG_NON_UNIT - - if side: - blas_side = cublas.CUBLAS_SIDE_RIGHT - else: - blas_side = cublas.CUBLAS_SIDE_LEFT - - trsm( - cublas_handle, blas_side, uplo, - trans, diag, - m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) \ No newline at end of file From 08a768702b132cbc87bdfdb18796a427bbb5a348 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:24:41 +0000 Subject: [PATCH 246/518] used serinv_matmul once for testing --- src/serinv/algs/pobtas.py | 6 ++++-- src/serinv/utils/__init__.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git 
a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bab2a911..13fa4628 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -4,8 +4,10 @@ from serinv import ( ArrayLike, _get_module_from_array, + ) +from serinv.utils import serinv_matmul, serinv_solve_triangular def pobtas( L_diagonal_blocks: ArrayLike, @@ -86,9 +88,9 @@ def _pobtas( lower=True, ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + , B[i * diag_blocksize : (i + 1) * diag_blocksize] ) B[-arrow_blocksize:] -= ( diff --git a/src/serinv/utils/__init__.py b/src/serinv/utils/__init__.py index 4079f81c..00a5a351 100644 --- a/src/serinv/utils/__init__.py +++ b/src/serinv/utils/__init__.py @@ -8,6 +8,9 @@ from serinv.utils.pobtx import allocate_pobtx_permutation_buffers from serinv.utils.pobtax import allocate_pobtax_permutation_buffers +from serinv.utils.matmul import serinv_matmul +from serinv.utils.trsm import serinv_solve_triangular + __all__ = [ "check_block_dd", "check_ddbta", @@ -15,4 +18,6 @@ "allocate_ddbtax_permutation_buffers", "allocate_pobtx_permutation_buffers", "allocate_pobtax_permutation_buffers", + "serinv_matmul", + "serinv_solve_triangluar" ] From dfef1722d5724c748500fb65a17877104b84afd8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:26:52 +0000 Subject: [PATCH 247/518] used serinv_solve_triangular once for testing --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 13fa4628..72a4c751 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -82,10 +82,10 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize 
: (i + 1) * diag_blocksize] = serinv_solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + lower=True, side=0 ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( From 0b9ceb7998929dc24ed01429fb6b5e5f33cc9039 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:41:41 +0000 Subject: [PATCH 248/518] updated solve_triangular_deive to newer cupy implementation --- src/serinv/utils/trsm.py | 142 ++++++++++++++++++++++++++++++--------- 1 file changed, 111 insertions(+), 31 deletions(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 159f32a4..a2fac8d5 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -66,10 +66,23 @@ def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, _util._assert_cupy_array(a, b) - if len(a.shape) != 2 or a.shape[0] != a.shape[1]: - raise ValueError('expected square matrix') - if len(a) != len(b): - raise ValueError('incompatible dimensions') + if a.ndim == 2: + if a.shape[0] != a.shape[1]: + raise ValueError('expected square matrix') + if len(a) != len(b): + raise ValueError('incompatible dimensions') + batch_count = 0 + elif a.ndim > 2: + if a.shape[-1] != a.shape[-2]: + raise ValueError('expected a batch of square matrices') + if a.shape[:-2] != b.shape[:a.ndim - 2]: + raise ValueError('incompatible batch count') + if b.ndim < a.ndim - 1 or a.shape[-2] != b.shape[a.ndim - 2]: + raise ValueError('incompatible dimensions') + batch_count = math.prod(a.shape[:-2]) + else: + raise ValueError( + 'expected one square matrix or a batch of square matrices') # Cast to float32 or float64 if a.dtype.char in 'fd': @@ -77,9 +90,6 @@ def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, else: dtype = np.promote_types(a.dtype.char, 'f') - a = cp.array(a, dtype=dtype, order='F', copy=False) - b = cp.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) - if check_finite: if a.dtype.kind 
== 'f' and not cp.isfinite(a).all(): raise ValueError( @@ -88,17 +98,69 @@ def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, raise ValueError( 'array must not contain infs or NaNs') - m, n = (b.size, 1) if b.ndim == 1 else b.shape - cublas_handle = device.get_cublas_handle() + if batch_count: + m, n = b.shape[-2:] if b.ndim == a.ndim else (b.shape[-1], 1) + + a_new_shape = (batch_count, m, m) + b_shape = b.shape + b_data_ptr = b.data.ptr + # trsm receives Fortran array, but we want zero copy + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + # normal Fortran upper == transpose C lower + trans = cublas.CUBLAS_OP_T + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), dtype=dtype) + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + # transpose Fortran upper == normal C lower + trans = cublas.CUBLAS_OP_N + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), dtype=dtype) + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: + if dtype == 'f' or dtype == 'd': + # real numbers + # Hermitian Fortran upper == transpose Fortran upper + # == normal C lower + trans = cublas.CUBLAS_OP_N + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), + dtype=dtype) + else: + # complex numbers + trans = cublas.CUBLAS_OP_C + a = cp.ascontiguousarray( + a.reshape(*a_new_shape).transpose(0, 2, 1), dtype=dtype) + else: # know nothing about `trans`, just convert C to Fortran + a = cp.ascontiguousarray( + a.reshape(*a_new_shape).transpose(0, 2, 1), dtype=dtype) + b = cp.ascontiguousarray( + b.reshape(batch_count, m, n).transpose(0, 2, 1), dtype=dtype) + if b.data.ptr == b_data_ptr and not overwrite_b: + b = b.copy() + + start = a.data.ptr + step = m * m * a.itemsize + stop = start + step * batch_count + a_array = cp.arange(start, stop, step, dtype=cp.uintp) + + start = b.data.ptr + step = m * n * b.itemsize + stop = start + step * batch_count + b_array = cp.arange(start, stop, step, dtype=cp.uintp) + else: + a = 
cp.array(a, dtype=dtype, order='F', copy=None) + b = cp.array(b, dtype=dtype, order='F', + copy=(None if overwrite_b else True)) + + m, n = (b.size, 1) if b.ndim == 1 else b.shape - if dtype == 'f': - trsm = cublas.strsm - elif dtype == 'd': - trsm = cublas.dtrsm - elif dtype == 'F': - trsm = cublas.ctrsm - else: # dtype == 'D' - trsm = cublas.ztrsm + if trans == 'N': + trans = cublas.CUBLAS_OP_N + elif trans == 'T': + trans = cublas.CUBLAS_OP_T + elif trans == 'C': + trans = cublas.CUBLAS_OP_C + + cublas_handle = device.get_cublas_handle() one = np.array(1, dtype=dtype) if lower: @@ -106,27 +168,45 @@ def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, else: uplo = cublas.CUBLAS_FILL_MODE_UPPER - if trans == 'N': - trans = cublas.CUBLAS_OP_N - elif trans == 'T': - trans = cublas.CUBLAS_OP_T - elif trans == 'C': - trans = cublas.CUBLAS_OP_C - if unit_diagonal: diag = cublas.CUBLAS_DIAG_UNIT else: diag = cublas.CUBLAS_DIAG_NON_UNIT if side: - blas_side = cublas.CUBLAS_SIDE_RIGHT + side = cublas.CUBLAS_SIDE_RIGHT else: - blas_side = cublas.CUBLAS_SIDE_LEFT - - trsm( - cublas_handle, blas_side, uplo, - trans, diag, - m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) + side = cublas.CUBLAS_SIDE_LEFT + + if batch_count: + if dtype == 'f': + trsm = cublas.strsmBatched + elif dtype == 'd': + trsm = cublas.dtrsmBatched + elif dtype == 'F': + trsm = cublas.ctrsmBatched + else: # dtype == 'D' + trsm = cublas.ztrsmBatched + trsm( + cublas_handle, side, uplo, + trans, diag, + m, n, one.ctypes.data, a_array.data.ptr, m, + b_array.data.ptr, m, batch_count) + return b.transpose(0, 2, 1).reshape(b_shape) + else: + if dtype == 'f': + trsm = cublas.strsm + elif dtype == 'd': + trsm = cublas.dtrsm + elif dtype == 'F': + trsm = cublas.ctrsm + else: # dtype == 'D' + trsm = cublas.ztrsm + trsm( + cublas_handle, side, uplo, + trans, diag, + m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) + return b def solve_triangular_host(a, b, trans=0, lower=False, 
unit_diagonal=False, From b073708a49e2937ccdf10b7d2ad93605ac8c2e65 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:42:51 +0000 Subject: [PATCH 249/518] reomved serinv_solve_triangular --- src/serinv/algs/pobtas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 72a4c751..67d21328 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -82,10 +82,10 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = serinv_solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, side=0 + lower=True ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( From 0ec3284bffbe75cbd0f09b5aa3fe0767f5290f28 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:45:41 +0000 Subject: [PATCH 250/518] debug messages --- src/serinv/algs/pobtas.py | 4 ++-- src/serinv/utils/trsm.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 67d21328..72a4c751 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -82,10 +82,10 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = serinv_solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True + lower=True, side=0 ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index a2fac8d5..c0da9e20 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -21,11 +21,14 @@ def 
serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept plus the side parameter which can either be 0 or 1 for left or right hand side """ + print("one") xp = _get_module_from_array(a) - + print("two") if xp == np: + print("three") return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) elif xp == cp: + print("four") return solve_triangular_device(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) else: ModuleNotFoundError("Unknown Module") From 176577000454eaed0019d882934bdfd17eee0074 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:48:24 +0000 Subject: [PATCH 251/518] print xp --- src/serinv/utils/trsm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index c0da9e20..550540cd 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -23,6 +23,7 @@ def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, """ print("one") xp = _get_module_from_array(a) + print(xp) print("two") if xp == np: print("three") From 3c56e0900aac05146e91a326ce9b5577a5591555 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:49:21 +0000 Subject: [PATCH 252/518] fixed module tuple --- src/serinv/utils/trsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 550540cd..6190075b 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -22,7 +22,7 @@ def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, plus the side parameter which can either be 0 or 1 for left or right hand side """ print("one") - xp = _get_module_from_array(a) + xp, la = _get_module_from_array(a) print(xp) print("two") if xp == np: From bc1d67fa91e12d1c15dcc4f7e891610ec9b8be54 Mon Sep 17 00:00:00 2001 
From: 03szust Date: Wed, 4 Jun 2025 08:52:47 +0000 Subject: [PATCH 253/518] removed transpose --- src/serinv/utils/trsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 6190075b..2c218fd2 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -324,7 +324,7 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans_a=trans, diag=unit_diagonal, side=side) else: # transposed system is solved since trtrs expects Fortran ordering - x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=not lower, trans_a=not trans, diag=unit_diagonal, side=side) return x \ No newline at end of file From d756ab9d48aacbd4d7d4615c2f8d97c406325f91 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 08:59:10 +0000 Subject: [PATCH 254/518] print trsm func --- src/serinv/utils/trsm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 2c218fd2..334e31bd 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -310,6 +310,7 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) trsm, = get_blas_funcs(('trsm',), (a1, b1)) + print(trsm) if a1.dtype.char in 'fd': dtype = a1.dtype @@ -324,7 +325,7 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans_a=trans, diag=unit_diagonal, side=side) else: # transposed system is solved since trtrs expects Fortran ordering - x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=not lower, + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, trans_a=not trans, diag=unit_diagonal, side=side) return x \ No newline at end of file From cf79d8d8aa48e38293cf0f5f14ef644ab639bc46 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 09:11:45 +0000 Subject: [PATCH 
255/518] changed print --- src/serinv/utils/trsm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 334e31bd..ae310acb 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -21,10 +21,10 @@ def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept plus the side parameter which can either be 0 or 1 for left or right hand side """ - print("one") + print(a) xp, la = _get_module_from_array(a) print(xp) - print("two") + print(b) if xp == np: print("three") return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) From 4554d592f2822aca14472564a6fa1681e92a8769 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 09:31:16 +0000 Subject: [PATCH 256/518] changed alpha to 1 --- src/serinv/utils/trsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index ae310acb..64ed8578 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -318,7 +318,7 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, dtype = np.promote_types(a1.dtype.char, 'f') one = np.array(1, dtype=dtype) - alpha = one.ctypes.data + alpha = 1 if a1.flags.f_contiguous or trans == 2: x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, From 5a84dea9619938e2598b8d3d976b8a98cba1a848 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 09:32:02 +0000 Subject: [PATCH 257/518] removed one --- src/serinv/utils/trsm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/serinv/utils/trsm.py b/src/serinv/utils/trsm.py index 64ed8578..c8214a9b 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/utils/trsm.py @@ -317,7 +317,6 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, else: dtype = 
np.promote_types(a1.dtype.char, 'f') - one = np.array(1, dtype=dtype) alpha = 1 if a1.flags.f_contiguous or trans == 2: From a7d4c183896c1018927b1594fc6e611b14c983dc Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 10:27:31 +0000 Subject: [PATCH 258/518] changed matmul host to own implementation --- src/serinv/utils/matmul.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index aa68ab31..7febd69e 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -18,7 +18,7 @@ def serinv_matmul (a, b): xp, la = _get_module_from_array(a) if xp == np: - return matmul(a, b) + return matmul_gemm_host(a, b) elif xp == cp: return gemm('N', 'N', a, b) else: @@ -125,11 +125,8 @@ def _solve_triangular(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): else: dtype = np.promote_types(a1.dtype.char, 'f') - one = np.array(1, dtype=dtype) - zero =np.array(0, dtype=dtype) - alpha = one.ctypes.data - beta = zero.ctypes.data - + alpha = 1 + beta = 0 x = gemm(alpha, a1.T, b1.T, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) From 3e92f8960d21a55ad22b4b738c4fc4a1ac44faa6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 10:28:32 +0000 Subject: [PATCH 259/518] changed array ordering --- src/serinv/utils/matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index 7febd69e..e67b4fdb 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -128,7 +128,7 @@ def _solve_triangular(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): alpha = 1 beta = 0 - x = gemm(alpha, a1.T, b1.T, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) return x From 372ea0ec405dbf9862398d71268e238638a09882 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 10:29:38 +0000 
Subject: [PATCH 260/518] changed name of matmul function --- src/serinv/utils/matmul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index e67b4fdb..b8ccb241 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -109,12 +109,12 @@ def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=Fal ).dtype return np.empty_like(b1, dtype=dt_nonempty) - x = _solve_triangular(a1, b1, trans_a, trans_b, overwrite_c) + x = _matmul_gemm(a1, b1, trans_a, trans_b, overwrite_c) return x # solve_triangular without the input validation -def _solve_triangular(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): +def _matmul_gemm(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) From cc5138e448edab9e2ccc452fe6c1465527f6ee1e Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 4 Jun 2025 10:42:11 +0000 Subject: [PATCH 261/518] expose trans param for matmul --- src/serinv/utils/matmul.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index b8ccb241..a2cc7d27 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -13,14 +13,14 @@ except (ImportError, ImportWarning, ModuleNotFoundError): pass -def serinv_matmul (a, b): +def serinv_matmul (a, b, trans_a = 'N', trans_b = 'N'): """Wrapper to call GeMM for host or device""" xp, la = _get_module_from_array(a) if xp == np: - return matmul_gemm_host(a, b) + return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) elif xp == cp: - return gemm('N', 'N', a, b) + return gemm(trans_a, trans_b, a, b) else: ModuleNotFoundError("Unknown Module") From 96fb56beb604c77a209c79095baaebe439b5fa7f Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: Thu, 5 Jun 2025 15:18:36 +0200 Subject: [PATCH 262/518] unified (and added) test 
streaming for pobtaf/si --- tests/conftest.py | 3 --- tests/tests_algs/regular/tests_bt/test_pobtf.py | 12 ++++++++++++ tests/tests_algs/regular/tests_bt/test_pobts.py | 16 ++++++++-------- tests/tests_algs/regular/tests_bt/test_pobtsi.py | 12 ++++++++++++ 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3e624933..4d15c7db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. # Global pytest fixtures for the Serinv tests. - import pytest from serinv import backend_flags @@ -15,7 +14,6 @@ ] ) - DTYPE = [ pytest.param("float64", id="float64"), pytest.param("complex128", id="complex128"), @@ -26,7 +24,6 @@ pytest.param(3, id="diagonal_blocksize=3"), ] - @pytest.fixture(params=ARRAY_TYPE, autouse=True) def array_type(request: pytest.FixtureRequest) -> str: return request.param diff --git a/tests/tests_algs/regular/tests_bt/test_pobtf.py b/tests/tests_algs/regular/tests_bt/test_pobtf.py index d1969b05..0ac3ae89 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobtf.py +++ b/tests/tests_algs/regular/tests_bt/test_pobtf.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize @@ -11,6 +13,16 @@ if backend_flags["cupy_avail"]: import cupyx as cpx + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtf( diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index d137c796..9011caa1 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -3,7 +3,7 @@ import numpy as np import 
pytest -from ....conftest import ARRAY_TYPE as ARRAY_TYPE +from ....conftest import ARRAY_TYPE from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize, rhs @@ -11,20 +11,19 @@ from serinv.algs import pobtf, pobts if backend_flags["cupy_avail"]: + import cupyx as cpx + ARRAY_TYPE.extend( [ pytest.param("streaming", id="streaming"), ] ) -if backend_flags["cupy_avail"]: - import cupyx as cpx + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param -@pytest.fixture(params=ARRAY_TYPE, autouse=True) -def array_type(request: pytest.FixtureRequest) -> str: - return request.param - @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) def test_pobts( @@ -34,7 +33,7 @@ def test_pobts( array_type: str, dtype: np.dtype, ): - + A = dd_bt( diagonal_blocksize, n_diag_blocks, @@ -79,6 +78,7 @@ def test_pobts( pobtf( A_diagonal_blocks, A_lower_diagonal_blocks, + device_streaming=True if array_type == "streaming" else False, ) # Forward solve: Y=L^{-1}B diff --git a/tests/tests_algs/regular/tests_bt/test_pobtsi.py b/tests/tests_algs/regular/tests_bt/test_pobtsi.py index 22ec1d3a..5463a7b9 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobtsi.py +++ b/tests/tests_algs/regular/tests_bt/test_pobtsi.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize @@ -11,6 +13,16 @@ if backend_flags["cupy_avail"]: import cupyx as cpx + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtsi( From 33ba4677c15ad94b6759b47ea56917a6275216c2 Mon Sep 17 00:00:00 2001 From: vincent-maillou Date: 
Thu, 5 Jun 2025 15:24:18 +0200 Subject: [PATCH 263/518] just ran `black .` --- src/serinv/algs/pobtas.py | 230 ++++++++++-------- src/serinv/algs/pobts.py | 130 +++++----- src/serinv/wrappers/ddbtars.py | 1 - src/serinv/wrappers/pddbtasc.py | 2 +- src/serinv/wrappers/pddbtasci.py | 2 +- src/serinv/wrappers/pddbtsc.py | 4 +- src/serinv/wrappers/pddbtsci.py | 2 +- tests/conftest.py | 1 + .../permuted/test_bt/test_pobts_permuted.py | 4 +- .../regular/tests_bta/test_pobtaf.py | 1 + .../regular/tests_bta/test_pobtas.py | 3 +- 11 files changed, 207 insertions(+), 173 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 1575a4ac..cc51e3bb 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -223,7 +223,8 @@ def _pobtas_permuted( ) else: raise ValueError(f"Invalid transpose argument: {trans}.") - + + def _pobtas_streaming( L_diagonal_blocks: ArrayLike, L_lower_diagonal_blocks: ArrayLike, @@ -238,8 +239,6 @@ def _pobtas_streaming( raise NotImplementedError( "Host<->Device streaming only works when host-arrays are given." 
) - - cp, cu_la = _get_module_from_str(module_str="cupy") @@ -253,18 +252,13 @@ def _pobtas_streaming( h2d_stream = cp.cuda.Stream(non_blocking=True) d2h_stream = cp.cuda.Stream(non_blocking=True) - - # Device Buffers # B Buffers - B_shape = B[-arrow_blocksize:] # block template + B_shape = B[-arrow_blocksize:] # block template B_arrow_tip_d = cp.empty_like(B_shape) - B_shape = B[0 : diag_blocksize] - B_d = cp.empty( - (2, *B_shape.shape), dtype=B_shape.dtype - ) - + B_shape = B[0:diag_blocksize] + B_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) # L Buffers L_diagonal_blocks_d = cp.empty( @@ -307,9 +301,9 @@ def _pobtas_streaming( L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) # --- H2D: transfers --- - B_d[0].set(arr=B[0 : diag_blocksize], stream = h2d_stream) + B_d[0].set(arr=B[0:diag_blocksize], stream=h2d_stream) h2d_B_events[0].record(stream=h2d_stream) - + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) h2d_diagonal_events[0].record(stream=h2d_stream) @@ -318,39 +312,36 @@ def _pobtas_streaming( # --- D2H: event --- d2h_B_events[1].record(stream=d2h_stream) - + n_diag_blocks: int = L_diagonal_blocks.shape[0] if n_diag_blocks > 1: - L_lower_diagonal_blocks_d[0].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + L_lower_diagonal_blocks_d[0].set( + arr=L_lower_diagonal_blocks[0], stream=h2d_stream + ) h2d_lower_diagonal_events[0].record(stream=h2d_stream) - # --- Computations --- for i in range(0, n_diag_blocks - 1): # pass next B block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) - B_d[(i + 1) % 2].set( arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], - stream = h2d_stream + stream=h2d_stream, ) h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) - if i + 1 < n_diag_blocks - 1: # pass next diagonal block h2d_stream.wait_event(compute_current_B_events[(i + 1) % 2]) L_diagonal_blocks_d[(i + 1) % 2].set( - arr=L_diagonal_blocks[i + 1], - stream=h2d_stream + 
arr=L_diagonal_blocks[i + 1], stream=h2d_stream ) - - h2d_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) + h2d_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: # Solve current B block @@ -363,22 +354,20 @@ def _pobtas_streaming( ) compute_current_B_events[i % 2].record(stream=compute_stream) - + # Pass current B block back - if i + 1 < n_diag_blocks - 1: # Pass next lower diagonal block h2d_stream.wait_event(compute_next_B_events[(i + 1) % 2]) L_lower_diagonal_blocks_d[(i + 1) % 2].set( - arr=L_lower_diagonal_blocks[i + 1], - stream=h2d_stream + arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream ) - + h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) d2h_stream.wait_event(compute_current_B_events[i % 2]) - d2h_stream.wait_event(h2d_lower_diagonal_events[(i+1) % 2]) + d2h_stream.wait_event(h2d_lower_diagonal_events[(i + 1) % 2]) B_d[i % 2].get( out=B[i * diag_blocksize : (i + 1) * diag_blocksize], @@ -387,131 +376,145 @@ def _pobtas_streaming( ) d2h_B_events[i % 2].record(stream=d2h_stream) - + with compute_stream: # Update next B block compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) - B_d[(i + 1) % 2] -= ( - L_lower_diagonal_blocks_d[i % 2] - @ B_d[i % 2] - ) + B_d[(i + 1) % 2] -= L_lower_diagonal_blocks_d[i % 2] @ B_d[i % 2] + + compute_next_B_events[i % 2].record(stream=compute_stream) - compute_next_B_events[i % 2].record(stream=compute_stream) - if i + 1 < n_diag_blocks - 1: # Pass next lower arrow block h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) L_lower_arrow_blocks_d[(i + 1) % 2].set( - arr=L_lower_arrow_blocks[i + 1], - stream=h2d_stream + arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream ) - + h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) with compute_stream: # Update arrow tip compute_stream.wait_event(h2d_arrow_events[i % 2]) - - B_arrow_tip_d -= ( - L_lower_arrow_blocks_d[i % 2] - @ B_d[i % 2] - ) + + B_arrow_tip_d -= L_lower_arrow_blocks_d[i % 2] @ B_d[i % 2] 
compute_arrow_B_events[i % 2].record(stream=compute_stream) # Pass arrow tip back d2h_stream.wait_event(compute_arrow_B_events[n_diag_blocks % 2]) - - B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - d2h_tip_events[n_diag_blocks % 2].record(stream=d2h_stream) + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) + d2h_tip_events[n_diag_blocks % 2].record(stream=d2h_stream) if not partial: # Pass last blocks h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream) - + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream + ) + h2d_diagonal_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) - + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_lower_arrow_blocks[-1], stream=h2d_stream + ) + h2d_arrow_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) - with compute_stream: # Solve last B block compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) B_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], - B_d[(n_diag_blocks - 1) % 2], - lower=True + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, ) - + compute_partial_events[0].record(stream=compute_stream) # Pass last B block back d2h_stream.wait_event(compute_partial_events[0]) B_d[(n_diag_blocks - 1) % 2].get( - out=B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize], - stream=d2h_stream, - blocking=False + out=B[ + (n_diag_blocks - 1) + * diag_blocksize : n_diag_blocks + * diag_blocksize + ], + stream=d2h_stream, + blocking=False, ) - + d2h_B_events[0].record(stream=d2h_stream) with compute_stream: # Solve arrow tip 
compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) - B_arrow_tip_d -= (L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2] @ B_d[(n_diag_blocks - 1) % 2]) - B_arrow_tip_d = cu_la.solve_triangular(L_arrow_tip_block_d, B_arrow_tip_d, lower=True) + B_arrow_tip_d -= ( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2] + @ B_d[(n_diag_blocks - 1) % 2] + ) + B_arrow_tip_d = cu_la.solve_triangular( + L_arrow_tip_block_d, B_arrow_tip_d, lower=True + ) compute_partial_events[1].record(stream=compute_stream) d2h_stream.wait_event(compute_partial_events[1]) - B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) elif trans == "T" or trans == "C": # ----- Backward substitution ----- # Buffers - B_previous_d = cp.empty( - (2, *B_shape.shape), dtype=B_shape.dtype - ) + B_previous_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) # Delete helper variable del B_shape - + # Events compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] h2d_events = [cp.cuda.Event(), cp.cuda.Event()] d2h_events = [cp.cuda.Event(), cp.cuda.Event()] - + # --- H2D: transfers --- B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) B_d[(n_diag_blocks - 1) % 2].set( - arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], - stream=h2d_stream + arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[-1], stream=h2d_stream + ) + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_lower_arrow_blocks[-1], stream=h2d_stream ) - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_lower_arrow_blocks[-1], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 
2].record(stream=h2d_stream) - + # ----- Backward substitution ----- if not partial: - + with compute_stream: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) - B_arrow_tip_d = cu_la.solve_triangular( + B_arrow_tip_d = cu_la.solve_triangular( L_arrow_tip_block_d, B_arrow_tip_d, lower=True, @@ -519,14 +522,13 @@ def _pobtas_streaming( ) # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) - B_previous_d[(n_diag_blocks - 1) % 2] = ( - cu_la.solve_triangular( - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], - B_d[(n_diag_blocks - 1) % 2] - - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].conj().T @ B_arrow_tip_d, - lower=True, - trans="C", - ) + B_previous_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2] + - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].conj().T + @ B_arrow_tip_d, + lower=True, + trans="C", ) compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) @@ -534,39 +536,61 @@ def _pobtas_streaming( # Pass arrow tip back d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) - B_arrow_tip_d.get(out=B[-arrow_blocksize:], stream=d2h_stream, blocking=False,) - + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) if n_diag_blocks > 1: B_d[n_diag_blocks % 2].set( - arr=B[-arrow_blocksize - (2 * diag_blocksize) : -arrow_blocksize - diag_blocksize], - stream=h2d_stream + arr=B[ + -arrow_blocksize + - (2 * diag_blocksize) : -arrow_blocksize + - diag_blocksize + ], + stream=h2d_stream, + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_diagonal_blocks[-2], stream=h2d_stream ) - L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) - L_lower_arrow_blocks_d[n_diag_blocks % 2].set(arr=L_lower_arrow_blocks[-2], stream=h2d_stream) - L_lower_diagonal_blocks_d[n_diag_blocks % 
2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) - + L_lower_arrow_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_arrow_blocks[-2], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_diagonal_blocks[-1], stream=h2d_stream + ) + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) for i in range(n_diag_blocks - 2, -1, -1): - + if i > 0: # Pass new blocks h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) - L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) - L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) - L_lower_arrow_blocks_d[(i - 1) % 2].set(arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream) + B_d[(i - 1) % 2].set( + arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_arrow_blocks_d[(i - 1) % 2].set( + arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream + ) h2d_events[i % 2].record(stream=h2d_stream) - + with compute_stream: # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - + B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] @@ -583,10 +607,9 @@ def _pobtas_streaming( d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) B_previous_d[(i - 1) % 2].get( - out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], - stream=d2h_stream, - blocking=False - + out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=d2h_stream, + blocking=False, ) d2h_events[i % 2].record(stream=d2h_stream) @@ -597,6 +620,5 @@ def 
_pobtas_streaming( else: raise ValueError(f"Invalid transpose argument: {trans}.") - - cp.cuda.Device().synchronize() \ No newline at end of file + cp.cuda.Device().synchronize() diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 295be756..dc570116 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -181,13 +181,9 @@ def _pobts_streaming( # Device Buffers # B Buffers - B_shape = B[0 : diag_blocksize] - B_d = cp.empty( - (2, *B_shape.shape), dtype=B_shape.dtype - ) - B_previous_d = cp.empty( - (2, *B_shape.shape), dtype=B_shape.dtype - ) + B_shape = B[0:diag_blocksize] + B_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + B_previous_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) del B_shape # L Buffers @@ -213,12 +209,11 @@ def _pobts_streaming( h2d_events[1].record(stream=h2d_stream) if n_diag_blocks > 1: - B_d[1].set( - arr=B[diag_blocksize : (2 * diag_blocksize)], - stream=h2d_stream - ) + B_d[1].set(arr=B[diag_blocksize : (2 * diag_blocksize)], stream=h2d_stream) L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) - L_lower_diagonal_blocks_d[1].set(arr=L_lower_diagonal_blocks[0], stream=h2d_stream) + L_lower_diagonal_blocks_d[1].set( + arr=L_lower_diagonal_blocks[0], stream=h2d_stream + ) h2d_events[0].record(stream=h2d_stream) @@ -226,28 +221,33 @@ def _pobts_streaming( # Solve first B block compute_stream.wait_event(h2d_events[1]) - B_previous_d[0] = ( - cu_la.solve_triangular( - L_diagonal_blocks_d[0], - B_d[0], - lower=True, - ) + B_previous_d[0] = cu_la.solve_triangular( + L_diagonal_blocks_d[0], + B_d[0], + lower=True, ) compute_B_events[0].record(stream=compute_stream) for i in range(1, n_diag_blocks): - + if i + 1 < n_diag_blocks: # Pass next blocks h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) - B_d[(i + 1) % 2].set(arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=h2d_stream) - L_diagonal_blocks_d[(i + 1) % 2].set(arr=L_diagonal_blocks[i + 1], 
stream=h2d_stream) - L_lower_diagonal_blocks_d[(i + 1) % 2].set(arr=L_lower_diagonal_blocks[i], stream=h2d_stream) - + B_d[(i + 1) % 2].set( + arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_diagonal_blocks[i + 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_lower_diagonal_blocks[i], stream=h2d_stream + ) + h2d_events[i % 2].record(stream=h2d_stream) - + with compute_stream: # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} compute_stream.wait_event(h2d_events[(i + 1) % 2]) @@ -256,8 +256,7 @@ def _pobts_streaming( B_previous_d[i % 2] = cu_la.solve_triangular( L_diagonal_blocks_d[i % 2], B_d[i % 2] - - L_lower_diagonal_blocks_d[i % 2] - @ B_previous_d[(i + 1) % 2], + - L_lower_diagonal_blocks_d[i % 2] @ B_previous_d[(i + 1) % 2], lower=True, ) @@ -265,38 +264,44 @@ def _pobts_streaming( # Pass previous B block back d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - + B_previous_d[(i + 1) % 2].get( - out=B[(i - 1) * diag_blocksize : i * diag_blocksize], - stream=d2h_stream, - blocking=False + out=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=d2h_stream, + blocking=False, ) - + d2h_events[i % 2].record(stream=d2h_stream) # Pass last B block back d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2]) - - B_previous_d[(n_diag_blocks + 1) % 2].get(out=B[-diag_blocksize:], stream=d2h_stream, blocking=False) - - + + B_previous_d[(n_diag_blocks + 1) % 2].get( + out=B[-diag_blocksize:], stream=d2h_stream, blocking=False + ) + elif trans == "T" or trans == "C": # ----- Backward substitution ----- # --- H2D: transfers --- B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(arr=L_diagonal_blocks[-1], stream=h2d_stream) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[-1], stream=h2d_stream + ) 
h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) if n_diag_blocks > 1: B_d[n_diag_blocks % 2].set( - arr=B[-(2 * diag_blocksize) : -diag_blocksize], - stream=h2d_stream + arr=B[-(2 * diag_blocksize) : -diag_blocksize], stream=h2d_stream + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_diagonal_blocks[-2], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_diagonal_blocks[-1], stream=h2d_stream ) - L_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_diagonal_blocks[-2], stream=h2d_stream) - L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(arr=L_lower_diagonal_blocks[-1], stream=h2d_stream) h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) @@ -304,31 +309,34 @@ def _pobts_streaming( # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) - B_previous_d[(n_diag_blocks - 1) % 2] = ( - cu_la.solve_triangular( - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], - B_d[(n_diag_blocks - 1) % 2], - lower=True, - trans="C", - ) + B_previous_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + trans="C", ) compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) - - for i in range(n_diag_blocks - 2, -1, -1): - + if i > 0: # pass next blocks h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_d[(i - 1) % 2].set(arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], stream=h2d_stream) - L_diagonal_blocks_d[(i - 1) % 2].set(arr=L_diagonal_blocks[i - 1], stream=h2d_stream) - L_lower_diagonal_blocks_d[(i - 1) % 2].set(arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream) + B_d[(i - 1) % 2].set( + arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i - 1) % 2].set( + 
arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream + ) h2d_events[i % 2].record(stream=h2d_stream) - + with compute_stream: # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} compute_stream.wait_event(h2d_events[(i - 1) % 2]) @@ -348,16 +356,20 @@ def _pobts_streaming( # Pass previous B block back d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) - B_previous_d[(i - 1) % 2].get(out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], stream=d2h_stream, blocking=False) - + B_previous_d[(i - 1) % 2].get( + out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + d2h_events[i % 2].record(stream=d2h_stream) # Pass last B block back d2h_stream.wait_event(compute_B_events[0]) - + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) else: raise ValueError(f"Invalid transpose argument: {trans}.") - - cp.cuda.Device().synchronize() \ No newline at end of file + + cp.cuda.Device().synchronize() diff --git a/src/serinv/wrappers/ddbtars.py b/src/serinv/wrappers/ddbtars.py index beb34390..d836d25a 100644 --- a/src/serinv/wrappers/ddbtars.py +++ b/src/serinv/wrappers/ddbtars.py @@ -13,7 +13,6 @@ import cupyx as cpx - def allocate_ddbtars( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index 3b9a4d4a..da2fcff4 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -42,7 +42,7 @@ def pddbtasc( The arrow tip block of the block tridiagonal with arrowhead matrix. comm : MPI.Comm The MPI communicator. Default is MPI.COMM_WORLD. 
- + Keyword Arguments ----------------- rhs : dict diff --git a/src/serinv/wrappers/pddbtasci.py b/src/serinv/wrappers/pddbtasci.py index f86b6a9c..0ed92861 100644 --- a/src/serinv/wrappers/pddbtasci.py +++ b/src/serinv/wrappers/pddbtasci.py @@ -43,7 +43,7 @@ def pddbtasci( The arrow tip block of the block tridiagonal with arrowhead matrix. comm : MPI.Comm The MPI communicator. Default is MPI.COMM_WORLD. - + Keyword Arguments ----------------- rhs : dict diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index 357b08f6..fc0a3765 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -33,7 +33,7 @@ def pddbtsc( The upper diagonal blocks of the block tridiagonal with arrowhead matrix. comm : MPI.Comm The MPI communicator. Default is MPI.COMM_WORLD. - + Keyword Arguments ----------------- rhs : dict @@ -179,4 +179,4 @@ def pddbtsc( quadratic=quadratic, ) - comm.Barrier() \ No newline at end of file + comm.Barrier() diff --git a/src/serinv/wrappers/pddbtsci.py b/src/serinv/wrappers/pddbtsci.py index 144054f8..9c494e13 100644 --- a/src/serinv/wrappers/pddbtsci.py +++ b/src/serinv/wrappers/pddbtsci.py @@ -34,7 +34,7 @@ def pddbtsci( The upper diagonal blocks of the block tridiagonal matrix. comm : MPI.Comm The MPI communicator. Default is MPI.COMM_WORLD. 
- + Keyword Arguments ----------------- rhs : dict diff --git a/tests/conftest.py b/tests/conftest.py index 4d15c7db..25b716a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,7 @@ pytest.param(3, id="diagonal_blocksize=3"), ] + @pytest.fixture(params=ARRAY_TYPE, autouse=True) def array_type(request: pytest.FixtureRequest) -> str: return request.param diff --git a/tests/tests_algs/permuted/test_bt/test_pobts_permuted.py b/tests/tests_algs/permuted/test_bt/test_pobts_permuted.py index bba1068a..7a934c61 100644 --- a/tests/tests_algs/permuted/test_bt/test_pobts_permuted.py +++ b/tests/tests_algs/permuted/test_bt/test_pobts_permuted.py @@ -44,9 +44,7 @@ def test_pobts_permuted( A_diagonal_blocks, A_lower_diagonal_blocks, _, - ) = bt_dense_to_arrays( - A.copy(), diagonal_blocksize, n_diag_blocks - ) + ) = bt_dense_to_arrays(A.copy(), diagonal_blocksize, n_diag_blocks) # Allocate permutation buffer buffer = allocate_pobtx_permutation_buffers( diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index 98756357..920c6292 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -25,6 +25,7 @@ def array_type(request: pytest.FixtureRequest) -> str: return request.param + @pytest.mark.mpi_skip() def test_pobtaf( diagonal_blocksize: int, diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index d58e3a19..ffc290c2 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -25,6 +25,7 @@ def array_type(request: pytest.FixtureRequest) -> str: return request.param + @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) def test_pobtas( @@ -35,7 +36,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): - + A = dd_bta( diagonal_blocksize, arrowhead_blocksize, From 
6ff4ed716dce6852034466a3861f784d8f49fb8b Mon Sep 17 00:00:00 2001 From: 03szust Date: Fri, 6 Jun 2025 09:04:02 +0000 Subject: [PATCH 264/518] changed errors --- src/serinv/algs/pobtas.py | 2 +- src/serinv/algs/pobts.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 1575a4ac..60d2b97d 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -235,7 +235,7 @@ def _pobtas_streaming( ): arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks) if arr_module.__name__ != "numpy": - raise NotImplementedError( + raise TypeError( "Host<->Device streaming only works when host-arrays are given." ) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 295be756..d065b160 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -164,7 +164,7 @@ def _pobts_streaming( ): arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks) if arr_module.__name__ != "numpy": - raise NotImplementedError( + raise TypeError( "Host<->Device streaming only works when host-arrays are given." 
) From 7c4f97bc02f1a00895c99a4a7e390ae79b56c88f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 07:43:36 +0000 Subject: [PATCH 265/518] removed local functions from pobtas to prepare for renaming and moving --- src/serinv/algs/pobtas.py | 6 +++--- src/serinv/utils/matmul.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 72a4c751..ec7ef158 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -82,15 +82,15 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = serinv_solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], lower=True, side=0 ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= serinv_matmul( + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( L_lower_diagonal_blocks[i] - , B[i * diag_blocksize : (i + 1) * diag_blocksize] + @ B[i * diag_blocksize : (i + 1) * diag_blocksize] ) B[-arrow_blocksize:] -= ( diff --git a/src/serinv/utils/matmul.py b/src/serinv/utils/matmul.py index a2cc7d27..f57e401a 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/utils/matmul.py @@ -9,7 +9,7 @@ try: import cupy as cp - from cupy.cublas import gemm + from cupy.cublas import gemm as cp_gemm except (ImportError, ImportWarning, ModuleNotFoundError): pass @@ -20,7 +20,7 @@ def serinv_matmul (a, b, trans_a = 'N', trans_b = 'N'): if xp == np: return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) elif xp == cp: - return gemm(trans_a, trans_b, a, b) + return cp_gemm(trans_a, trans_b, a, b) else: ModuleNotFoundError("Unknown Module") From b56e59156a37ae91a84c79a40da13641b3b7ceb6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 07:54:34 +0000 Subject: [PATCH 266/518] changed solve_triangluar to trsm in pobtaf 
--- src/serinv/algs/pobtaf.py | 25 ++++++++++--------- src/serinv/algs/pobtas.py | 1 - src/serinv/block_primitive/__init__.py | 7 ++++++ .../matmul.py => block_primitive/gemm.py} | 2 +- src/serinv/{utils => block_primitive}/trsm.py | 2 +- src/serinv/utils/__init__.py | 5 +--- 6 files changed, 23 insertions(+), 19 deletions(-) create mode 100644 src/serinv/block_primitive/__init__.py rename src/serinv/{utils/matmul.py => block_primitive/gemm.py} (98%) rename src/serinv/{utils => block_primitive}/trsm.py (99%) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 05f58d3b..4a99ca19 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -7,6 +7,7 @@ _get_cholesky, ) +from serinv.block_primitive import trsm, gemm def pobtaf( A_diagonal_blocks: ArrayLike, @@ -118,7 +119,7 @@ def _pobtaf( # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :].conj().T, lower=True, @@ -129,7 +130,7 @@ def _pobtaf( # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, @@ -164,7 +165,7 @@ def _pobtaf( # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} L_lower_arrow_blocks[-1, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, lower=True, @@ -210,7 +211,7 @@ def _pobtaf_permuted( # Compute lower factors # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :].conj().T, lower=True, @@ -221,7 +222,7 @@ def _pobtaf_permuted( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} buffer[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, lower=True, @@ -232,7 +233,7 @@ def _pobtaf_permuted( # L_{ndb+1, i} 
= A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, @@ -391,7 +392,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -418,7 +419,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -488,7 +489,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -654,7 +655,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -681,7 +682,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -701,7 +702,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_la.solve_triangular( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, diff --git 
a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index ec7ef158..f4b825d2 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -7,7 +7,6 @@ ) -from serinv.utils import serinv_matmul, serinv_solve_triangular def pobtas( L_diagonal_blocks: ArrayLike, diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py new file mode 100644 index 00000000..b8cf0541 --- /dev/null +++ b/src/serinv/block_primitive/__init__.py @@ -0,0 +1,7 @@ +from serinv.block_primitive.gemm import gemm +from serinv.block_primitive.trsm import trsm + +__all__ = [ + "gemm", + "trsm" +] \ No newline at end of file diff --git a/src/serinv/utils/matmul.py b/src/serinv/block_primitive/gemm.py similarity index 98% rename from src/serinv/utils/matmul.py rename to src/serinv/block_primitive/gemm.py index f57e401a..975a4b85 100644 --- a/src/serinv/utils/matmul.py +++ b/src/serinv/block_primitive/gemm.py @@ -13,7 +13,7 @@ except (ImportError, ImportWarning, ModuleNotFoundError): pass -def serinv_matmul (a, b, trans_a = 'N', trans_b = 'N'): +def gemm (a, b, trans_a = 'N', trans_b = 'N'): """Wrapper to call GeMM for host or device""" xp, la = _get_module_from_array(a) diff --git a/src/serinv/utils/trsm.py b/src/serinv/block_primitive/trsm.py similarity index 99% rename from src/serinv/utils/trsm.py rename to src/serinv/block_primitive/trsm.py index c8214a9b..1809289a 100644 --- a/src/serinv/utils/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -14,7 +14,7 @@ except (ImportError, ImportWarning, ModuleNotFoundError): pass -def serinv_solve_triangular(a, b, trans=0, lower = False, unit_diagonal=False, +def trsm(a, b, trans=0, lower = False, unit_diagonal=False, overwrite_b=False, check_finite=False, side=0): """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device diff --git a/src/serinv/utils/__init__.py b/src/serinv/utils/__init__.py index 00a5a351..1c54d228 100644 --- a/src/serinv/utils/__init__.py 
+++ b/src/serinv/utils/__init__.py @@ -8,8 +8,7 @@ from serinv.utils.pobtx import allocate_pobtx_permutation_buffers from serinv.utils.pobtax import allocate_pobtax_permutation_buffers -from serinv.utils.matmul import serinv_matmul -from serinv.utils.trsm import serinv_solve_triangular + __all__ = [ "check_block_dd", @@ -18,6 +17,4 @@ "allocate_ddbtax_permutation_buffers", "allocate_pobtx_permutation_buffers", "allocate_pobtax_permutation_buffers", - "serinv_matmul", - "serinv_solve_triangluar" ] From bdea80fab4b23c4e5d845fecaab7fdf584f5a236 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 07:56:01 +0000 Subject: [PATCH 267/518] =?UTF-8?q?removed=20l=C3=B6eftover=20side=20param?= =?UTF-8?q?eter=20from=20pobtas?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f4b825d2..dbc2d916 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -84,7 +84,7 @@ def _pobtas( B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, side=0 + lower=True ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( From 3bd5108f786b8853448f17bb1672c8677e07cd2f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 07:57:04 +0000 Subject: [PATCH 268/518] removed double conjugate once for testing side --- src/serinv/algs/pobtaf.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 4a99ca19..7cbf38ff 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -121,11 +121,9 @@ def _pobtaf( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + 
lower=True, side = 1 ) - .conj() - .T ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} From 5f25de7bb03423374d1da0f68c52770555fcd7d6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:02:02 +0000 Subject: [PATCH 269/518] test tom check schape and size of trsm with side --- src/serinv/algs/pobtaf.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7cbf38ff..8af9708e 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -117,13 +117,30 @@ def _pobtaf( # L_{i, i} = chol(A_{i, i}) L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :]) + + ### + # Testing for shape and size + print(L_diagonal_blocks[i, :, :]) + print(A_lower_diagonal_blocks[i, :, :].conj().T) + print(A_lower_diagonal_blocks[i, :, :]) + + L_test = trsm( + L_diagonal_blocks[i, :, :], + A_lower_diagonal_blocks[i, :, :], + lower=True, side = 1 + ) + print(L_test) + ### + # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :], - lower=True, side = 1 + A_lower_diagonal_blocks[i, :, :].conj().T, + lower=True, ) + .conj() + .T ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} From 332e1ebbe3e0100c698aea1734bfe63902a4839e Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:06:12 +0000 Subject: [PATCH 270/518] forced error to test shapes --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 8af9708e..d458ade7 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -137,7 +137,7 @@ def _pobtaf( trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + lower=True, side = 1 ) .conj() .T From d5c59a41a0745144b1d673b88e13c336058031a4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:07:28 +0000 Subject: 
[PATCH 271/518] removed some debug messages --- src/serinv/block_primitive/trsm.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py index 1809289a..0ca1e0ca 100644 --- a/src/serinv/block_primitive/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -21,15 +21,10 @@ def trsm(a, b, trans=0, lower = False, unit_diagonal=False, For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept plus the side parameter which can either be 0 or 1 for left or right hand side """ - print(a) xp, la = _get_module_from_array(a) - print(xp) - print(b) if xp == np: - print("three") return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) elif xp == cp: - print("four") return solve_triangular_device(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) else: ModuleNotFoundError("Unknown Module") From 24a7bb0529ff76f26d49fd3f771bcb517fa1829f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:12:00 +0000 Subject: [PATCH 272/518] more testing --- src/serinv/algs/pobtaf.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d458ade7..cbcde138 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -120,16 +120,30 @@ def _pobtaf( ### # Testing for shape and size + print("###") print(L_diagonal_blocks[i, :, :]) + print("###") print(A_lower_diagonal_blocks[i, :, :].conj().T) + print("###") print(A_lower_diagonal_blocks[i, :, :]) - + print("###") L_test = trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], lower=True, side = 1 ) print(L_test) + print("###") + L_test = ( + trsm( + L_diagonal_blocks[i, :, :], + A_lower_diagonal_blocks[i, :, :].conj().T, + lower=True, side = 0 + ) + .conj() + .T + ) + print(L_test) ### # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} From 
11315008611041e481c91596569e8ceb6f0f0074 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:13:13 +0000 Subject: [PATCH 273/518] more debug --- src/serinv/algs/pobtaf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index cbcde138..d5eadefb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -127,6 +127,7 @@ def _pobtaf( print("###") print(A_lower_diagonal_blocks[i, :, :]) print("###") + print("side = 1 sol") L_test = trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], @@ -134,6 +135,7 @@ def _pobtaf( ) print(L_test) print("###") + print("side = 0 sol") L_test = ( trsm( L_diagonal_blocks[i, :, :], From fda845619c3041cb90261c60512e194daa26d303 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:14:18 +0000 Subject: [PATCH 274/518] swapped A and B in test --- src/serinv/algs/pobtaf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d5eadefb..ee090233 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,8 +129,9 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], + L_diagonal_blocks[i, :, :], + lower=True, side = 1 ) print(L_test) From e1b57756f61d9e5a3c7dfb3478bf290c7c747a43 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:17:27 +0000 Subject: [PATCH 275/518] swapped back --- src/serinv/algs/pobtaf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index ee090233..d5eadefb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,9 +129,8 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - A_lower_diagonal_blocks[i, :, :], L_diagonal_blocks[i, :, :], - + A_lower_diagonal_blocks[i, :, :], lower=True, side = 1 ) print(L_test) From 
df626e3718302144590d59d46de46824554a9fc8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:28:49 +0000 Subject: [PATCH 276/518] transpose L --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d5eadefb..aaa2da25 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,7 +129,7 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - L_diagonal_blocks[i, :, :], + L_diagonal_blocks[i, :, :].conj.T, A_lower_diagonal_blocks[i, :, :], lower=True, side = 1 ) From 0d60090cf487c32ac4ef531b92074ab64f795b54 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:29:37 +0000 Subject: [PATCH 277/518] typo --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index aaa2da25..6ebb6cbb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,7 +129,7 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - L_diagonal_blocks[i, :, :].conj.T, + L_diagonal_blocks[i, :, :].conj().T, A_lower_diagonal_blocks[i, :, :], lower=True, side = 1 ) From 142acdc89d332a76b2cc9b257054a7902869fcc4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:31:28 +0000 Subject: [PATCH 278/518] changed conj.T to the trans param --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 6ebb6cbb..7c004630 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -129,9 +129,9 @@ def _pobtaf( print("###") print("side = 1 sol") L_test = trsm( - L_diagonal_blocks[i, :, :].conj().T, + L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], - lower=True, side = 1 + trans=2,lower=True, side = 1 ) print(L_test) print("###") From bfbdcf1d1df880be943628d4cb218ee21e4e6e52 Mon Sep 17 00:00:00 2001 
From: 03szust Date: Tue, 10 Jun 2025 08:32:49 +0000 Subject: [PATCH 279/518] actually implement frist trsm --- src/serinv/algs/pobtaf.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7c004630..68d196ee 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -131,7 +131,7 @@ def _pobtaf( L_test = trsm( L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], - trans=2,lower=True, side = 1 + trans='C',lower=True, side=1 ) print(L_test) print("###") @@ -152,11 +152,10 @@ def _pobtaf( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, side = 1 + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T + ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} From e80022c959e314edfe7f18569e5fc412839d5396 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:36:49 +0000 Subject: [PATCH 280/518] changed all trsm's --- src/serinv/algs/pobtaf.py | 103 +++++++++----------------------------- 1 file changed, 25 insertions(+), 78 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 68d196ee..0577b66c 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -117,37 +117,6 @@ def _pobtaf( # L_{i, i} = chol(A_{i, i}) L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :]) - - ### - # Testing for shape and size - print("###") - print(L_diagonal_blocks[i, :, :]) - print("###") - print(A_lower_diagonal_blocks[i, :, :].conj().T) - print("###") - print(A_lower_diagonal_blocks[i, :, :]) - print("###") - print("side = 1 sol") - L_test = trsm( - L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :], - trans='C',lower=True, side=1 - ) - print(L_test) - print("###") - print("side = 0 sol") - L_test = ( - trsm( - L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, side = 0 - 
) - .conj() - .T - ) - print(L_test) - ### - # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( @@ -162,11 +131,9 @@ def _pobtaf( L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :].conj().T, - lower=True, + A_lower_arrow_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # Update next diagonal block @@ -197,11 +164,9 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], - A_lower_arrow_blocks[-1, :, :].conj().T, - lower=True, + A_lower_arrow_blocks[-1, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} @@ -243,11 +208,9 @@ def _pobtaf_permuted( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} @@ -255,21 +218,17 @@ def _pobtaf_permuted( trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, - lower=True, + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :].conj().T, - lower=True, + A_lower_arrow_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # Update next diagonal block @@ -422,13 +381,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -449,13 +406,11 @@ def _pobtaf_streaming( with compute_stream: 
compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_arrow_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_arrow_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_arrow_events[i % 2].record(stream=compute_stream) @@ -519,13 +474,11 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], - A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, - lower=True, + A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_arrow_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) @@ -685,13 +638,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -712,13 +663,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, - lower=True, + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_arrow_events[i % 2].record(stream=compute_stream) @@ -732,13 +681,11 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, - lower=True, + trans='C',lower=True, side=1 ) - 
.conj() - .T ) cp_upper_nested_dissection_buffer_events[i % 2].record( stream=compute_stream From 2a8291051c701dee4115a616a57229148d8da451 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:40:36 +0000 Subject: [PATCH 281/518] after the previous version faailed, this is the second attempt --- src/serinv/algs/pobtaf.py | 66 +++++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0577b66c..5184a19d 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -164,9 +164,11 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], - A_lower_arrow_blocks[-1, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks[-1, :, :].conj().T, + lower=True, ) + .conj() + .T ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} @@ -208,9 +210,11 @@ def _pobtaf_permuted( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks[i, :, :].conj().T, + lower=True, ) + .conj() + .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} @@ -218,17 +222,21 @@ def _pobtaf_permuted( trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, - trans='C',lower=True, side=1 + lower=True, ) + .conj() + .T ) # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks[i, :, :].conj().T, + lower=True, ) + .conj() + .T ) # Update next diagonal block @@ -381,11 +389,13 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, 
+ lower=True, ) + .conj() + .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -406,11 +416,13 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_arrow_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) compute_arrow_events[i % 2].record(stream=compute_stream) @@ -474,11 +486,13 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], - A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) compute_arrow_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) @@ -638,11 +652,13 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -663,11 +679,13 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, - trans='C',lower=True, side=1 + lower=True, ) + .conj() + .T ) cp_arrow_events[i % 2].record(stream=compute_stream) @@ -681,11 +699,13 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: 
L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - trsm( + cu_trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, - trans='C',lower=True, side=1 + lower=True, ) + .conj() + .T ) cp_upper_nested_dissection_buffer_events[i % 2].record( stream=compute_stream From df87182a09a9a15bcacdc9096e9bac14c6370a34 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:49:27 +0000 Subject: [PATCH 282/518] removed side from arrow because of dim mismatch, added it to other arrow to check for dim mismatch --- src/serinv/algs/pobtaf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 5184a19d..c2c01698 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -131,9 +131,11 @@ def _pobtaf( L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :], - trans='C',lower=True, side=1 + A_lower_arrow_blocks[i, :, :].conj().T, + lower=True, ) + .conj() + .T ) # Update next diagonal block @@ -164,11 +166,9 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], - A_lower_arrow_blocks[-1, :, :].conj().T, - lower=True, + A_lower_arrow_blocks[-1, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} From db39806384a2d72db626cf88782d7091dfcada73 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 08:52:03 +0000 Subject: [PATCH 283/518] implemented trsm side right for all non arrow solves --- src/serinv/algs/pobtaf.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index c2c01698..a85a6a03 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -166,9 +166,11 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], - A_lower_arrow_blocks[-1, :, :], 
- trans='C',lower=True, side=1 + A_lower_arrow_blocks[-1, :, :].conj().T, + lower=True, ) + .conj() + .T ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} @@ -210,11 +212,9 @@ def _pobtaf_permuted( L_lower_diagonal_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} @@ -389,13 +389,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -416,7 +414,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -486,7 +484,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -652,13 +650,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -679,7 +675,7 @@ def 
_pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -699,7 +695,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_trsm( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, From 91b6d9bcb2dddf575269bcfb15a17a64e6518a1a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:21:36 +0000 Subject: [PATCH 284/518] imported cupy gemm to local --- src/serinv/block_primitive/gemm.py | 157 +++++++++++++++++++++++++++-- 1 file changed, 151 insertions(+), 6 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 975a4b85..78189ede 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -10,6 +10,9 @@ try: import cupy as cp from cupy.cublas import gemm as cp_gemm + from cupy_backends.cuda.libs import cublas + from cupy import _core + from cupy.cuda import device except (ImportError, ImportWarning, ModuleNotFoundError): pass @@ -20,12 +23,12 @@ def gemm (a, b, trans_a = 'N', trans_b = 'N'): if xp == np: return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) elif xp == cp: - return cp_gemm(trans_a, trans_b, a, b) + return matmul_gemm_device(trans_a, trans_b, a, b) else: ModuleNotFoundError("Unknown Module") -def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): +def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): """ Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. 
@@ -95,12 +98,16 @@ def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=Fal a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) + c1 = _asarray_validated(c, check_finite=check_finite) if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: raise ValueError('expected square matrix') if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + if beta != 0 and c1 == None: + raise ValueError('expected C matrix') # accommodate empty arrays if b1.size == 0: @@ -109,12 +116,12 @@ def matmul_gemm_host(a, b, trans_a=0, trans_b=0, overwrite_c=0, check_finite=Fal ).dtype return np.empty_like(b1, dtype=dt_nonempty) - x = _matmul_gemm(a1, b1, trans_a, trans_b, overwrite_c) + x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x # solve_triangular without the input validation -def _matmul_gemm(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): +def _matmul_gemm(a1, b1, alpha=1, beta=0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) @@ -127,8 +134,146 @@ def _matmul_gemm(a1, b1, trans_a=0, trans_b=0, overwrite_c=0): alpha = 1 beta = 0 - - x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + if beta == 0: + x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + else: + x = gemm(alpha, a1, b1, beta, c1, trans_a, trans_b, overwrite_c) return x + + + +def _trans_to_cublas_op(trans): + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + trans = cublas.CUBLAS_OP_N + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + trans = cublas.CUBLAS_OP_T + elif trans == 'H' or trans == cublas.CUBLAS_OP_C: + trans = cublas.CUBLAS_OP_C + else: + raise TypeError('invalid trans (actual: {})'.format(trans)) + return trans + +def _decide_ld_and_trans(a, 
trans): + ld = None + if trans in (cublas.CUBLAS_OP_N, cublas.CUBLAS_OP_T): + if a._f_contiguous: + ld = a.shape[0] + elif a._c_contiguous: + ld = a.shape[1] + trans = 1 - trans + return ld, trans + + +def _change_order_if_necessary(a, lda): + if lda is None: + lda = a.shape[0] + if not a._f_contiguous: + a = a.copy(order='F') + return a, lda + +def _get_scalar_ptr(a, dtype): + if isinstance(a, cp.ndarray): + if a.dtype != dtype: + a = cp.array(a, dtype=dtype) + a_ptr = a.data.ptr + else: + if not (isinstance(a, np.ndarray) and a.dtype == dtype): + a = np.array(a, dtype=dtype) + a_ptr = a.ctypes.data + return a, a_ptr + + +def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): + """Computes out = alpha * op(a) @ op(b) + beta * out + + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'H'. + op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', + op(b) = b.T.conj() if transb is 'H'. + """ + assert a.ndim == b.ndim == 2 + assert a.dtype == b.dtype + dtype = a.dtype.char + if dtype == 'f': + func = cublas.sgemm + elif dtype == 'd': + func = cublas.dgemm + elif dtype == 'F': + func = cublas.cgemm + elif dtype == 'D': + func = cublas.zgemm + else: + raise TypeError('invalid dtype') + + + transa = _trans_to_cublas_op(transa) + transb = _trans_to_cublas_op(transb) + if transa == cublas.CUBLAS_OP_N: + m, k = a.shape + else: + k, m = a.shape + if transb == cublas.CUBLAS_OP_N: + n = b.shape[1] + assert b.shape[0] == k + else: + n = b.shape[0] + assert b.shape[1] == k + if out is None: + out = cp.empty((m, n), dtype=dtype, order='F') + beta = 0.0 + else: + assert out.ndim == 2 + assert out.shape == (m, n) + assert out.dtype == dtype + + alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype) + beta, beta_ptr = _get_scalar_ptr(beta, a.dtype) + handle = device.get_cublas_handle() + orig_mode = cublas.getPointerMode(handle) + if isinstance(alpha, cp.ndarray) or isinstance(beta, cp.ndarray): + if not isinstance(alpha, 
cp.ndarray): + alpha = cp.array(alpha) + alpha_ptr = alpha.data.ptr + if not isinstance(beta, cp.ndarray): + beta = cp.array(beta) + beta_ptr = beta.data.ptr + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_DEVICE) + else: + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_HOST) + + lda, transa = _decide_ld_and_trans(a, transa) + ldb, transb = _decide_ld_and_trans(b, transb) + if not (lda is None or ldb is None): + if out._f_contiguous: + try: + func(handle, transa, transb, m, n, k, alpha_ptr, + a.data.ptr, lda, b.data.ptr, ldb, beta_ptr, out.data.ptr, + m) + finally: + cublas.setPointerMode(handle, orig_mode) + return out + elif out._c_contiguous: + # Computes out.T = alpha * b.T @ a.T + beta * out.T + try: + func(handle, 1 - transb, 1 - transa, n, m, k, alpha_ptr, + b.data.ptr, ldb, a.data.ptr, lda, beta_ptr, out.data.ptr, + n) + finally: + cublas.setPointerMode(handle, orig_mode) + return out + + a, lda = _change_order_if_necessary(a, lda) + b, ldb = _change_order_if_necessary(b, ldb) + c = out + if not out._f_contiguous: + c = out.copy(order='F') + try: + func(handle, transa, transb, m, n, k, alpha_ptr, a.data.ptr, lda, + b.data.ptr, ldb, beta_ptr, c.data.ptr, m) + finally: + cublas.setPointerMode(handle, orig_mode) + if not out._f_contiguous: + _core.elementwise_copy(c, out) + return out \ No newline at end of file From d8c1d7a5a7ca6a8e74f334b82c3eb6b1f079b3fc Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:22:42 +0000 Subject: [PATCH 285/518] added error to test --- src/serinv/block_primitive/gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 78189ede..1f270211 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -193,6 +193,7 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', op(b) = b.T.conj() if transb is 'H'. 
""" + ValueError("TEST") assert a.ndim == b.ndim == 2 assert a.dtype == b.dtype dtype = a.dtype.char From 87302016a2a95a8166bcd748fdd59dc5cfc1e298 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:23:07 +0000 Subject: [PATCH 286/518] fixed error --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 1f270211..ef3c39c1 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -193,7 +193,7 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', op(b) = b.T.conj() if transb is 'H'. """ - ValueError("TEST") + raise ValueError("TEST") assert a.ndim == b.ndim == 2 assert a.dtype == b.dtype dtype = a.dtype.char From e89414630ce0e20300c67bfbec32e46b0ce2d31a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:25:11 +0000 Subject: [PATCH 287/518] implemented one provsionary gemm in pobtaf --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index a85a6a03..267236f9 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,8 +142,8 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + - gemm(L_lower_diagonal_blocks[i, :, :] + , L_lower_diagonal_blocks[i, :, :].conj().T) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From d2c3cd403a427431c507aa68a3847a22eb713941 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:26:56 +0000 Subject: [PATCH 288/518] removed error for testing --- src/serinv/block_primitive/gemm.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index ef3c39c1..78189ede 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -193,7 +193,6 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', op(b) = b.T.conj() if transb is 'H'. """ - raise ValueError("TEST") assert a.ndim == b.ndim == 2 assert a.dtype == b.dtype dtype = a.dtype.char From d251adb965cc4600e4152f772a375674878dd9b3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:28:13 +0000 Subject: [PATCH 289/518] fixed validating array if no array was present to begin with --- src/serinv/block_primitive/gemm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 78189ede..4e332686 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -98,7 +98,8 @@ def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwr a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) - c1 = _asarray_validated(c, check_finite=check_finite) + if c != None: + c1 = _asarray_validated(c, check_finite=check_finite) if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: raise ValueError('expected square matrix') From 58429da259eb1b7f8ccdc091ff0ef188736ad740 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 09:28:57 +0000 Subject: [PATCH 290/518] fixed c1 one not existing if c was none --- src/serinv/block_primitive/gemm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 4e332686..c7c331bb 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -100,6 +100,8 @@ def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwr b1 = 
_asarray_validated(b, check_finite=check_finite) if c != None: c1 = _asarray_validated(c, check_finite=check_finite) + else: + c1 = None if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: raise ValueError('expected square matrix') From af86da81b9da30d29196a37bbd029f5a168da747 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:30:57 +0000 Subject: [PATCH 291/518] changed gemm to accomodate in place operations --- src/serinv/algs/pobtaf.py | 7 ++++--- src/serinv/block_primitive/gemm.py | 12 ++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 267236f9..bf8a129a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,9 +141,10 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - gemm(L_lower_diagonal_blocks[i, :, :] - , L_lower_diagonal_blocks[i, :, :].conj().T) + + gemm(L_lower_diagonal_blocks[i, :, :] + , L_lower_diagonal_blocks[i, :, :].conj().T, + A_diagonal_blocks[i + 1, :, :], -1.0, 1.0) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index c7c331bb..6d290e81 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -9,26 +9,25 @@ try: import cupy as cp - from cupy.cublas import gemm as cp_gemm from cupy_backends.cuda.libs import cublas from cupy import _core from cupy.cuda import device except (ImportError, ImportWarning, ModuleNotFoundError): pass -def gemm (a, b, trans_a = 'N', trans_b = 'N'): +def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): """Wrapper to call GeMM for host or device""" xp, la = _get_module_from_array(a) if xp == np: return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) elif xp == cp: - return 
matmul_gemm_device(trans_a, trans_b, a, b) + return matmul_gemm_device(trans_a, trans_b, a, b, c, alpha, beta) else: ModuleNotFoundError("Unknown Module") -def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): +def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): """ Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. @@ -124,7 +123,7 @@ def matmul_gemm_host(a, b, alpha=1, beta=0, c=None, trans_a=0, trans_b=0, overwr # solve_triangular without the input validation -def _matmul_gemm(a1, b1, alpha=1, beta=0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): +def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) @@ -146,7 +145,7 @@ def _matmul_gemm(a1, b1, alpha=1, beta=0, c1=None, trans_a=0, trans_b=0, overwri return x - +# Util functions for cupy gemm def _trans_to_cublas_op(trans): if trans == 'N' or trans == cublas.CUBLAS_OP_N: trans = cublas.CUBLAS_OP_N @@ -186,6 +185,7 @@ def _get_scalar_ptr(a, dtype): a = np.array(a, dtype=dtype) a_ptr = a.ctypes.data return a, a_ptr +# Util functions for cupy gemm end def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): From 575962169bdc0fcbf7e19fa623db7166eaffbcde Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:33:57 +0000 Subject: [PATCH 292/518] changed first gemm to trans_b = c --- src/serinv/algs/pobtaf.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index bf8a129a..993da40b 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,10 +141,11 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - - 
gemm(L_lower_diagonal_blocks[i, :, :] - , L_lower_diagonal_blocks[i, :, :].conj().T, - A_diagonal_blocks[i + 1, :, :], -1.0, 1.0) + A_diagonal_blocks[i + 1, :, :] + - gemm(L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C' + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From f3fb2b5f19d16da3cbebd81fd82bb177476fa511 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:35:20 +0000 Subject: [PATCH 293/518] fixed different trans name --- src/serinv/block_primitive/gemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 6d290e81..a3ec4d77 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -151,7 +151,7 @@ def _trans_to_cublas_op(trans): trans = cublas.CUBLAS_OP_N elif trans == 'T' or trans == cublas.CUBLAS_OP_T: trans = cublas.CUBLAS_OP_T - elif trans == 'H' or trans == cublas.CUBLAS_OP_C: + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: trans = cublas.CUBLAS_OP_C else: raise TypeError('invalid trans (actual: {})'.format(trans)) @@ -192,9 +192,9 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): """Computes out = alpha * op(a) @ op(b) + beta * out op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', - op(a) = a.T.conj() if transa is 'H'. + op(a) = a.T.conj() if transa is 'C'. op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', - op(b) = b.T.conj() if transb is 'H'. + op(b) = b.T.conj() if transb is 'C'. 
""" assert a.ndim == b.ndim == 2 assert a.dtype == b.dtype From a48f5687f6c8db4605e7f11ceccaa1b9d1533dab Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:36:19 +0000 Subject: [PATCH 294/518] used alpha param on first gemm --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 993da40b..08b2322a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,9 +142,9 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( A_diagonal_blocks[i + 1, :, :] - - gemm(L_lower_diagonal_blocks[i, :, :], + + gemm(L_lower_diagonal_blocks[i, :, :], L_lower_diagonal_blocks[i, :, :], - trans_b='C' + trans_b='C', alpha=-1.0 ) ) From 69da4ae227006f02df583d8a254ce6a34d8ca89a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:51:23 +0000 Subject: [PATCH 295/518] removed alpha and beta hardcoding --- src/serinv/block_primitive/gemm.py | 2 -- src/serinv/block_primitive/trsm.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index a3ec4d77..44673155 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -134,8 +134,6 @@ def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, ove else: dtype = np.promote_types(a1.dtype.char, 'f') - alpha = 1 - beta = 0 if beta == 0: x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) else: diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py index 0ca1e0ca..5f447e27 100644 --- a/src/serinv/block_primitive/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -305,7 +305,6 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) trsm, = get_blas_funcs(('trsm',), (a1, b1)) - 
print(trsm) if a1.dtype.char in 'fd': dtype = a1.dtype From 4e6a81ae6f54bcc362da19ecde10db7afd979909 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:53:47 +0000 Subject: [PATCH 296/518] changed to minus --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 08b2322a..bf8f9eee 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,7 +142,7 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( A_diagonal_blocks[i + 1, :, :] - + gemm(L_lower_diagonal_blocks[i, :, :], + - gemm(L_lower_diagonal_blocks[i, :, :], L_lower_diagonal_blocks[i, :, :], trans_b='C', alpha=-1.0 ) From f148ab3c1670a7b8a6ebb9c1800a345b258b9d27 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:55:54 +0000 Subject: [PATCH 297/518] inserted some debug messages --- src/serinv/block_primitive/gemm.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 44673155..bff8631c 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,7 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - + print(alpha) x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x @@ -134,6 +134,12 @@ def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, ove else: dtype = np.promote_types(a1.dtype.char, 'f') + print(a1) + print("###") + print(b1) + print("###") + print(alpha) + print("###") if beta == 0: x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) else: From c3de7469b639f47591134b85bd7061fc1c0ccd34 Mon Sep 17 00:00:00 2001 From: 03szust Date: 
Tue, 10 Jun 2025 10:56:20 +0000 Subject: [PATCH 298/518] reverted minus --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index bf8f9eee..08b2322a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,7 +142,7 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( A_diagonal_blocks[i + 1, :, :] - - gemm(L_lower_diagonal_blocks[i, :, :], + + gemm(L_lower_diagonal_blocks[i, :, :], L_lower_diagonal_blocks[i, :, :], trans_b='C', alpha=-1.0 ) From dac04895057f2ebae88a44b86abae9bcb437b294 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 10:58:17 +0000 Subject: [PATCH 299/518] exposed alpha, beta and c for host gemm --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index bff8631c..13ab9ac1 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -20,7 +20,7 @@ def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): xp, la = _get_module_from_array(a) if xp == np: - return matmul_gemm_host(a, b, trans_a=trans_a, trans_b=trans_b) + return matmul_gemm_host(a, b, c, alpha, beta, trans_a, trans_b) elif xp == cp: return matmul_gemm_device(trans_a, trans_b, a, b, c, alpha, beta) else: From cada2cb8c3d722b0ad8419cdc89534d39927874a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:08:15 +0000 Subject: [PATCH 300/518] convert alpha to complex for cgemm and zgemm host --- src/serinv/block_primitive/gemm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 13ab9ac1..70144426 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -209,8 +209,10 @@ def matmul_gemm_device(transa, 
transb, a, b, out=None, alpha=1.0, beta=0.0): func = cublas.dgemm elif dtype == 'F': func = cublas.cgemm + alpha = complex(alpha) elif dtype == 'D': func = cublas.zgemm + alpha = complex(alpha) else: raise TypeError('invalid dtype') From 13e2b79e3d7aa30e6a6282733b7226004e92054f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:11:20 +0000 Subject: [PATCH 301/518] inser dytpe debug --- src/serinv/block_primitive/gemm.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 70144426..e5eaea73 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,7 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - print(alpha) + print(alpha.dtype) x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x @@ -134,12 +134,6 @@ def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, ove else: dtype = np.promote_types(a1.dtype.char, 'f') - print(a1) - print("###") - print(b1) - print("###") - print(alpha) - print("###") if beta == 0: x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) else: @@ -209,10 +203,8 @@ def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): func = cublas.dgemm elif dtype == 'F': func = cublas.cgemm - alpha = complex(alpha) elif dtype == 'D': func = cublas.zgemm - alpha = complex(alpha) else: raise TypeError('invalid dtype') From 9d85d220ea29fa36bc2590934ab596b0f2934a19 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:13:04 +0000 Subject: [PATCH 302/518] changed type debug --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py 
b/src/serinv/block_primitive/gemm.py index e5eaea73..81f03e7d 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,7 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - print(alpha.dtype) + print(alpha.type()) x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x From ce4f0bcf8dd70f4557c6f4d4117c79b399ddae92 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:13:49 +0000 Subject: [PATCH 303/518] changed type debug again --- src/serinv/block_primitive/gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 81f03e7d..87cfc42c 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,6 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) + print(alpha) print(alpha.type()) x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x From fc12dc1e0305f7692ee8d208d26f44bbc565ef91 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:14:55 +0000 Subject: [PATCH 304/518] swapped order in function call --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 87cfc42c..baa2add4 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -20,7 +20,7 @@ def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): xp, la = _get_module_from_array(a) if xp == np: - return matmul_gemm_host(a, b, c, alpha, beta, trans_a, trans_b) + return matmul_gemm_host(a, b, alpha, beta, c, trans_a, trans_b) 
elif xp == cp: return matmul_gemm_device(trans_a, trans_b, a, b, c, alpha, beta) else: From 7887bb0aa9b4cfd3b1fbce285c32638ccb7dd0eb Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:15:37 +0000 Subject: [PATCH 305/518] removed debug --- src/serinv/block_primitive/gemm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index baa2add4..e0232472 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -117,8 +117,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) ).dtype return np.empty_like(b1, dtype=dt_nonempty) - print(alpha) - print(alpha.type()) + x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x From 1ae08d4ad8bc83ee35af3d03a917bd0942cb2a67 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:22:00 +0000 Subject: [PATCH 306/518] fully use gemm at first location --- src/serinv/algs/pobtaf.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 08b2322a..036d115c 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,10 +141,12 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - + gemm(L_lower_diagonal_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], - trans_b='C', alpha=-1.0 + + gemm( + L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 ) ) From cfd7bd7e91aabd43a9b84dea609cbe00aa29e87a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:23:14 +0000 Subject: [PATCH 307/518] changed check for existing c --- src/serinv/block_primitive/gemm.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index e0232472..29eee7fe 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -97,7 +97,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) - if c != None: + if c not None: c1 = _asarray_validated(c, check_finite=check_finite) else: c1 = None From 993c17f5044c992452910f7afc58c0624261387c Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:24:28 +0000 Subject: [PATCH 308/518] fixed c not being able to be true --- src/serinv/block_primitive/gemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 29eee7fe..61926c1a 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -97,10 +97,10 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) - if c not None: - c1 = _asarray_validated(c, check_finite=check_finite) - else: + if c == None: c1 = None + else: + c1 = _asarray_validated(c, check_finite=check_finite) if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: raise ValueError('expected square matrix') From c3aa5050223382595080ae3627e17ef00854a05f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:26:02 +0000 Subject: [PATCH 309/518] fixed c again --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 61926c1a..cce19df2 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -97,7 +97,7 @@ def matmul_gemm_host(a, 
b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov a1 = _asarray_validated(a, check_finite=check_finite) b1 = _asarray_validated(b, check_finite=check_finite) - if c == None: + if c is None: c1 = None else: c1 = _asarray_validated(c, check_finite=check_finite) From a38c39d3e2cbf4e801812703f009cf144cce4d9a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:26:55 +0000 Subject: [PATCH 310/518] further c fix --- src/serinv/block_primitive/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index cce19df2..4fc20b87 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -108,7 +108,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') - if beta != 0 and c1 == None: + if beta != 0 and c1 is None: raise ValueError('expected C matrix') # accommodate empty arrays From a873567e2faf095686b7dc06bf3cd498c971d317 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:28:55 +0000 Subject: [PATCH 311/518] second gemm --- src/serinv/algs/pobtaf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 036d115c..7e81ae41 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,7 +141,6 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - gemm( L_lower_diagonal_blocks[i, :, :], L_lower_diagonal_blocks[i, :, :], @@ -152,8 +151,12 @@ def _pobtaf( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + 
L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T From 73eeabecd15408478a7ee4bf7654f0802e9149d8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:36:13 +0000 Subject: [PATCH 312/518] removed square matrix check in gemm that was leftover from trsm --- src/serinv/block_primitive/gemm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 4fc20b87..e9d2d668 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -102,8 +102,6 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov else: c1 = _asarray_validated(c, check_finite=check_finite) - if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: - raise ValueError('expected square matrix') if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') From 899d7c94a96af4d916bb77560ad5e23c70bedb55 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:39:46 +0000 Subject: [PATCH 313/518] changed input validation --- src/serinv/block_primitive/gemm.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index e9d2d668..052cb6ae 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -102,9 +102,21 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov else: c1 = _asarray_validated(c, check_finite=check_finite) - - if a1.shape[0] != b1.shape[0]: - raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + if not trans_a and not trans_b: + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + elif 
trans_a and not trans_b: + if a1.shape[1] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + elif not trans_a and trans_b: + if a1.shape[0] != b1.shape[1]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + else: + if a1.shape[1] != b1.shape[1]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') if beta != 0 and c1 is None: raise ValueError('expected C matrix') From e02b484c19fdb5ee3602b62bf42029f67e0512c2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:41:17 +0000 Subject: [PATCH 314/518] third gemm --- src/serinv/algs/pobtaf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7e81ae41..57576074 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -161,8 +161,12 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + A_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) if factorize_last_block: From b3b8c2218ee7be6f913ac1958db86906d02eb878 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:43:04 +0000 Subject: [PATCH 315/518] full normal pobtaf gemm implemented --- src/serinv/algs/pobtaf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 57576074..4c3d85b7 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -186,8 +186,12 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :] - - L_lower_arrow_blocks[-1, :, :] @ L_lower_arrow_blocks[-1, :, :].conj().T + gemm( + 
L_lower_arrow_blocks[-1, :, :], + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 0bf8f588088988d3ad3fb455df5f17b41e3df8de Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:50:39 +0000 Subject: [PATCH 316/518] gemm in permuted pobtaf --- src/serinv/algs/pobtaf.py | 46 +++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 4c3d85b7..3631a522 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -257,40 +257,64 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - L_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + L_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T 
A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + gemm( + buffer[i, :, :], + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - A_lower_arrow_blocks[0, :, :] - - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + buffer[i, :, :].conj().T, + A_lower_arrow_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) From 50312bb6780faf9fb8dacd8eb7e5ab3c7adac4ba Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:52:28 +0000 Subject: [PATCH 317/518] rollback to just one gemm --- src/serinv/algs/pobtaf.py | 41 ++++++++++----------------------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3631a522..858e7edb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -258,8 +258,8 @@ def _pobtaf_permuted( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( gemm( - L_lower_diagonal_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :] + @ L_lower_diagonal_blocks[i, :, :].conj().T, A_diagonal_blocks[i + 1, :, :], trans_b='C', alpha=-1.0, beta=1.0 ) @@ -267,54 +267,33 @@ def _pobtaf_permuted( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i + 1, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + 
A_lower_arrow_blocks[i + 1, :, :] + - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_arrow_blocks[i, :, :], - L_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + L_arrow_tip_block[:, :] + - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - gemm( - buffer[i, :, :], - buffer[i, :, :], - A_diagonal_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - gemm( - buffer[i, :, :], - L_lower_diagonal_blocks[i, :, :], - trans_b='C', alpha=-1.0 - ) + -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - buffer[i, :, :].conj().T, - A_lower_arrow_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks[0, :, :] + - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T ) From 2cf395aa6a3762e0c89bd34302af9a5945bb336a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 11:53:15 +0000 Subject: [PATCH 318/518] removed leftover conj t --- src/serinv/algs/pobtaf.py | 43 +++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 858e7edb..c430c245 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -247,7 +247,7 @@ def _pobtaf_permuted( 
L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :].conj().T, + A_lower_arrow_blocks[i, :, :], lower=True, ) .conj() @@ -258,8 +258,8 @@ def _pobtaf_permuted( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( gemm( - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T, + L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], A_diagonal_blocks[i + 1, :, :], trans_b='C', alpha=-1.0, beta=1.0 ) @@ -267,33 +267,54 @@ def _pobtaf_permuted( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - L_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + L_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + gemm( + buffer[i, :, :], + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) ) # Update the top (first blocks) of the arrowhead # 
A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - A_lower_arrow_blocks[0, :, :] - - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + buffer[i, :, :].conj().T, + A_lower_arrow_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) From 3b3b244984e3f40356dc2a3cd2a9910e16ae60e2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:47:11 +0000 Subject: [PATCH 319/518] rollback to 1 gemm in permuted --- src/serinv/algs/pobtaf.py | 39 +++++++++------------------------------ 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index c430c245..efa94b0d 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -247,7 +247,7 @@ def _pobtaf_permuted( L_lower_arrow_blocks[i, :, :] = ( trsm( L_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i, :, :], + A_lower_arrow_blocks[i, :, :].conj().T, lower=True, ) .conj() @@ -267,54 +267,33 @@ def _pobtaf_permuted( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], - A_lower_arrow_blocks[i + 1, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks[i + 1, :, :] + - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_arrow_blocks[i, :, :], - L_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + L_arrow_tip_block[:, :] + - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = 
( - gemm( - buffer[i, :, :], - buffer[i, :, :], - A_diagonal_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - gemm( - buffer[i, :, :], - L_lower_diagonal_blocks[i, :, :], - trans_b='C', alpha=-1.0 - ) + -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], - buffer[i, :, :].conj().T, - A_lower_arrow_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks[0, :, :] + - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T ) From 88b3a4cfdf6542530823630db327626546dfe085 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:49:06 +0000 Subject: [PATCH 320/518] next gemm in permuted --- src/serinv/algs/pobtaf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index efa94b0d..43fde541 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -267,8 +267,12 @@ def _pobtaf_permuted( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the block at the tip of the arrowhead From 83f7247bb3e09cd0ef7aa2922997df5ff0db9d29 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:50:14 +0000 Subject: [PATCH 321/518] another gemm --- src/serinv/algs/pobtaf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git 
a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 43fde541..a0f1d720 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -278,8 +278,12 @@ def _pobtaf_permuted( # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - L_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + L_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern From 7cae093a90d3ef331be7759fd3d8b020cb822b1d Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:51:25 +0000 Subject: [PATCH 322/518] next gemm --- src/serinv/algs/pobtaf.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index a0f1d720..56347f21 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -289,7 +289,12 @@ def _pobtaf_permuted( # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + gemm( + buffer[i, :, :], + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T From 3f889e1220d469292f1de3ab1dd1d7e4da5acc9c Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:52:30 +0000 Subject: [PATCH 323/518] smaller gemm --- src/serinv/algs/pobtaf.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 56347f21..3b41a890 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -299,7 +299,12 @@ def _pobtaf_permuted( # A_{top, 
i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) + ) # Update the top (first blocks) of the arrowhead From 3c85f7bace8c6fea01ba3e534ddd19263242cb03 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:53:48 +0000 Subject: [PATCH 324/518] last permuted gemm --- src/serinv/algs/pobtaf.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3b41a890..b00ffb36 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -303,15 +303,18 @@ def _pobtaf_permuted( buffer[i, :, :], L_lower_diagonal_blocks[i, :, :], trans_b='C', alpha=-1.0 - ) - + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - A_lower_arrow_blocks[0, :, :] - - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + buffer[i, :, :], + A_lower_arrow_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) From ad936cbfcb64927243f909761ca412a6221d7731 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:57:38 +0000 Subject: [PATCH 325/518] first gemm in streaming --- src/serinv/algs/pobtaf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b00ffb36..81c24a5a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -481,9 +481,12 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + 
L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From 9c4f49628ad7b2cc30ad989ed861446c6ac8a185 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:58:43 +0000 Subject: [PATCH 326/518] second gemm streaming --- src/serinv/algs/pobtaf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 81c24a5a..3afb3a60 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -491,9 +491,12 @@ def _pobtaf_streaming( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) From e0e5bdd040382e5f3834090fec730e661cf69991 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 12:59:51 +0000 Subject: [PATCH 327/518] third gemm streaming --- src/serinv/algs/pobtaf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3afb3a60..8d29b1a7 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -502,9 +502,12 @@ def _pobtaf_streaming( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, 
:], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) From 4bb8e6ef0eba8d758abe8437013b0d4bb7e82549 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 13:04:48 +0000 Subject: [PATCH 328/518] two permuted streaming gemms --- src/serinv/algs/pobtaf.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 8d29b1a7..83104526 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -553,9 +553,12 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] - @ L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) @@ -774,16 +777,22 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, 
:], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T From 480d982b5ab2138a4dc286c22c1b7e6ac59ac48e Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 10 Jun 2025 13:08:47 +0000 Subject: [PATCH 329/518] implemented gemms for permuted streaming --- src/serinv/algs/pobtaf.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 83104526..0975d58a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -797,34 +797,46 @@ def _pobtaf_permuted_streaming( # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - -L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + trans_b='C', alpha=-1.0 + ) ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_arrow_bottom_top_block_d[:, :] = ( - A_arrow_bottom_top_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_arrow_bottom_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) cp_arrow_events_h2d_release[i % 
2].record(stream=compute_stream) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - A_diagonal_top_block_d[:, :] - - L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_diagonal_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # --- Device 2 Host transfers --- From bfadd081db9e57572abe9d19ef105d1d234d5c82 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 08:30:07 +0000 Subject: [PATCH 330/518] implemented a form of syrk/herk and added a error for testing --- src/serinv/algs/pobtaf.py | 3 +- src/serinv/block_primitive/gemm.py | 85 ++++------------------------ src/serinv/block_primitive/syherk.py | 71 +++++++++++++++++++++++ src/serinv/block_primitive/trsm.py | 9 +-- 4 files changed, 87 insertions(+), 81 deletions(-) create mode 100644 src/serinv/block_primitive/syherk.py diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0975d58a..b81680b3 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -148,7 +148,8 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) - + print(A_diagonal_blocks[i + 1, :, :]) + raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 052cb6ae..8924cb2b 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -28,71 +28,12 @@ def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): - """ - Solve the equation ``a x = b`` for `x`, assuming a is a triangular 
matrix. - - Parameters - ---------- - a : (M, M) array_like - A triangular matrix - b : (M,) or (M, N) array_like - Right-hand side matrix in ``a x = b`` - lower : bool, optional - Use only data contained in the lower triangle of `a`. - Default is to use upper triangle. - trans : {0, 1, 2, 'N', 'T', 'C'}, optional - Type of system to solve: - - ======== ========= - trans system - ======== ========= - 0 or 'N' a x = b - 1 or 'T' a^T x = b - 2 or 'C' a^H x = b - ======== ========= - unit_diagonal : bool, optional - If True, diagonal elements of `a` are assumed to be 1 and - will not be referenced. - overwrite_b : bool, optional - Allow overwriting data in `b` (may enhance performance) - check_finite : bool, optional - Whether to check that the input matrices contain only finite numbers. - Disabling may give a performance gain, but may result in problems - (crashes, non-termination) if the inputs do contain infinities or NaNs. - - Returns - ------- - x : (M,) or (M, N) ndarray - Solution to the system ``a x = b``. Shape of return matches `b`. - - Raises - ------ - LinAlgError - If `a` is singular - - Notes - ----- - .. versionadded:: 0.9.0 - - Examples - -------- - Solve the lower triangular system a x = b, where:: - - [3 0 0 0] [4] - a = [2 1 0 0] b = [2] - [1 0 1 0] [4] - [1 1 1 1] [2] - - >>> import numpy as np - >>> from scipy.linalg import solve_triangular - >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) - >>> b = np.array([4, 2, 4, 2]) - >>> x = solve_triangular(a, b, lower=True) - >>> x - array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) - >>> a.dot(x) # Check the result - array([ 4., 2., 4., 2.]) + """Computes out = alpha * op(a) @ op(b) + beta * out + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'C'. + op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', + op(b) = b.T.conj() if transb is 'C'. 
""" a1 = _asarray_validated(a, check_finite=check_finite) @@ -128,29 +69,27 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov ).dtype return np.empty_like(b1, dtype=dt_nonempty) + if c1 is not None: + overwrite_c = overwrite_c or _datacopied(c1, c) + x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) return x -# solve_triangular without the input validation +# gemm without the input validation def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) gemm, = get_blas_funcs(('gemm',), (a1, b1)) - if a1.dtype.char in 'fd': - dtype = a1.dtype - else: - dtype = np.promote_types(a1.dtype.char, 'f') - if beta == 0: - x = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + out = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) else: - x = gemm(alpha, a1, b1, beta, c1, trans_a, trans_b, overwrite_c) + out = gemm(alpha, a1, b1, beta, c1, trans_a, trans_b, overwrite_c) - return x + return out # Util functions for cupy gemm diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py new file mode 100644 index 00000000..d708ff11 --- /dev/null +++ b/src/serinv/block_primitive/syherk.py @@ -0,0 +1,71 @@ +from serinv import _get_module_from_array + +import numpy as np +from numpy.linalg import matmul + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +try: + import cupy as cp + from cupy_backends.cuda.libs import cublas + from cupy import _core + from cupy.cuda import device +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def syherk(a, c=None, trans=0, lower = False, unit_diagonal=False, + overwrite_c=False, check_finite=False, side=0): + """Wrapper for 
the trsm function to call depending on wheter the solve happens on the host or the device + + For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept + plus the side parameter which can either be 0 or 1 for left or right hand side + """ + xp, la = _get_module_from_array(a) + if xp == np: + return matmul_syherk_host(a, c, trans, lower, unit_diagonal, overwrite_c, check_finite, side) + elif xp == cp: + return matmul_syherk_device(a, c, trans, lower, unit_diagonal, overwrite_c, check_finite, side) + else: + ModuleNotFoundError("Unknown Module") + +def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, unit_diagonal=False, + overwrite_c=False, check_finite=True, side=0): + """Computes out = alpha * op(a) @ op(a)^T + beta * b + + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'C'. + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + if c is None: + c1 = None + else: + c1 = _asarray_validated(c, check_finite=check_finite) + + + if a1.shape[0] != c1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and c {c1.shape} are incompatible') + + + overwrite_c = overwrite_c or _datacopied(c1, c) + + x = _syherk(a1, c1, alpha, beta, trans, lower, unit_diagonal, overwrite_c, side) + return x + + +# syherk without the input validation +def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, unit_diagonal=False, + overwrite_c=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + + if np.iscomplexobj(a1): + syherk, = get_blas_funcs(('herk'), (a1, c1)) + else: + syherk, = get_blas_funcs(('syrk'), (a1, c1)) + + out = syherk(alpha, a1, beta, c1, trans, lower, overwrite_c) + + return out \ No newline at end of file diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py index 5f447e27..0185bed9 100644 --- a/src/serinv/block_primitive/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ 
-306,18 +306,13 @@ def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) trsm, = get_blas_funcs(('trsm',), (a1, b1)) - if a1.dtype.char in 'fd': - dtype = a1.dtype - else: - dtype = np.promote_types(a1.dtype.char, 'f') - - alpha = 1 + alpha = 1.0 if a1.flags.f_contiguous or trans == 2: x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, trans_a=trans, diag=unit_diagonal, side=side) else: - # transposed system is solved since trtrs expects Fortran ordering + # transposed system is solved since trsm expects Fortran ordering x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, trans_a=not trans, diag=unit_diagonal, side=side) From 4e6bbde3ea75a3c70a922c5a095e4a0834b04060 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 08:30:55 +0000 Subject: [PATCH 331/518] added another print for debug --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b81680b3..06f7731f 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -140,6 +140,7 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T + print(A_diagonal_blocks[i + 1, :, :]) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], From f8ccb36ffe94e05e5a2cb938a538a33ab715b721 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 08:33:09 +0000 Subject: [PATCH 332/518] added another print --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 06f7731f..e6ad9b74 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -141,6 +141,7 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks[i + 1, :, :]) + print(L_lower_diagonal_blocks[i, :, :] @ 
L_lower_diagonal_blocks[i, :, :].conj().T) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], From 256e3d0c2d57ae5a4acfed3d1aaa1901d64cb7d6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:01:36 +0000 Subject: [PATCH 333/518] implemented syherk. sadly it's not yet useful --- src/serinv/block_primitive/__init__.py | 4 +- src/serinv/block_primitive/syherk.py | 139 ++++++++++++++++++++++--- 2 files changed, 130 insertions(+), 13 deletions(-) diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py index b8cf0541..4d3895ec 100644 --- a/src/serinv/block_primitive/__init__.py +++ b/src/serinv/block_primitive/__init__.py @@ -1,7 +1,9 @@ from serinv.block_primitive.gemm import gemm from serinv.block_primitive.trsm import trsm +from serinv.block_primitive.trsm import syherk __all__ = [ "gemm", - "trsm" + "trsm", + "syherk" ] \ No newline at end of file diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index d708ff11..388a6b5f 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -10,13 +10,11 @@ try: import cupy as cp from cupy_backends.cuda.libs import cublas - from cupy import _core from cupy.cuda import device except (ImportError, ImportWarning, ModuleNotFoundError): pass -def syherk(a, c=None, trans=0, lower = False, unit_diagonal=False, - overwrite_c=False, check_finite=False, side=0): +def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower = False): """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept @@ -24,9 +22,9 @@ def syherk(a, c=None, trans=0, lower = False, unit_diagonal=False, """ xp, la = _get_module_from_array(a) if xp == np: - return matmul_syherk_host(a, c, trans, lower, unit_diagonal, overwrite_c, check_finite, side) + return 
matmul_syherk_host(a, c, alpha, beta, trans, lower) elif xp == cp: - return matmul_syherk_device(a, c, trans, lower, unit_diagonal, overwrite_c, check_finite, side) + return matmul_syherk_device(a, trans, c, alpha, beta, lower) else: ModuleNotFoundError("Unknown Module") @@ -43,11 +41,6 @@ def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, uni c1 = None else: c1 = _asarray_validated(c, check_finite=check_finite) - - - if a1.shape[0] != c1.shape[0]: - raise ValueError(f'shapes of a {a1.shape} and c {c1.shape} are incompatible') - overwrite_c = overwrite_c or _datacopied(c1, c) @@ -56,8 +49,8 @@ def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, uni # syherk without the input validation -def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, unit_diagonal=False, - overwrite_c=False, side=0): +def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, + overwrite_c=False): trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) @@ -68,4 +61,126 @@ def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, unit_diagona out = syherk(alpha, a1, beta, c1, trans, lower, overwrite_c) + return out + + + +# Util functions for cupy gemm +def _trans_to_cublas_op(trans): + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + trans = cublas.CUBLAS_OP_N + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + trans = cublas.CUBLAS_OP_T + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: + trans = cublas.CUBLAS_OP_C + else: + raise TypeError('invalid trans (actual: {})'.format(trans)) + return trans + +def _decide_ld_and_trans(a, trans): + ld = None + if trans in (cublas.CUBLAS_OP_N, cublas.CUBLAS_OP_T): + if a._f_contiguous: + ld = a.shape[0] + elif a._c_contiguous: + ld = a.shape[1] + trans = 1 - trans + return ld, trans + +def _get_scalar_ptr(a, dtype): + if isinstance(a, cp.ndarray): + if a.dtype != dtype: + a = cp.array(a, dtype=dtype) + a_ptr = a.data.ptr + else: + if not (isinstance(a, np.ndarray) and 
a.dtype == dtype): + a = np.array(a, dtype=dtype) + a_ptr = a.ctypes.data + return a, a_ptr +# Util functions for cupy gemm end + +def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=False): + """Computes out := alpha*op1(a)*op2(a) + beta*out + + op1(a) = a if trans is 'N', op2(a) = a.T if transa is 'N' + op1(a) = a.T if trans is 'T', op2(a) = a if transa is 'T' + lower specifies whether the upper or lower triangular + part of the array out is to be referenced + """ + assert a.ndim == 2 + dtype = a.dtype.char + if dtype == 'f': + func = cublas.ssyrk + elif dtype == 'd': + func = cublas.dsyrk + elif dtype == 'F': + func = cublas.cherk + elif dtype == 'D': + func = cublas.zherk + else: + raise TypeError('invalid dtype') + + trans = _trans_to_cublas_op(trans) + if trans == cublas.CUBLAS_OP_N: + n, k = a.shape + else: + k, n = a.shape + if out is None: + out = cp.zeros((n, n), dtype=dtype, order='F') + beta = 0.0 + else: + assert out.ndim == 2 + assert out.shape == (n, n) + assert out.dtype == dtype + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype) + beta, beta_ptr = _get_scalar_ptr(beta, a.dtype) + handle = device.get_cublas_handle() + orig_mode = cublas.getPointerMode(handle) + if isinstance(alpha, cp.ndarray) or isinstance(beta, cp.ndarray): + if not isinstance(alpha, cp.ndarray): + alpha = cp.array(alpha) + alpha_ptr = alpha.data.ptr + if not isinstance(beta, cp.ndarray): + beta = cp.array(beta) + beta_ptr = beta.data.ptr + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_DEVICE) + else: + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_HOST) + + lda, trans = _decide_ld_and_trans(a, trans) + ldo, _ = _decide_ld_and_trans(out, trans) + if out._c_contiguous: + if not a._c_contiguous: + a = a.copy(order='C') + trans = 1 - trans + lda = a.shape[1] + try: + func(handle, 1 - uplo, trans, n, k, + alpha_ptr, a.data.ptr, lda, + 
beta_ptr, out.data.ptr, ldo) + finally: + cublas.setPointerMode(handle, orig_mode) + + else: + if not a._f_contiguous: + a = a.copy(order='F') + lda = a.shape[0] + trans = 1 - trans + c = out + if not out._f_contiguous: + c = out.copy(order='F') + try: + func(handle, uplo, trans, n, k, + alpha_ptr, a.data.ptr, lda, + beta_ptr, out.data.ptr, ldo) + finally: + cublas.setPointerMode(handle, orig_mode) + if not out._f_contiguous: + out[...] = c return out \ No newline at end of file From 9d8ea282eb507659ad8b77863cfd3c0b35cb3ec8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:04:53 +0000 Subject: [PATCH 334/518] removed debug prints --- src/serinv/algs/pobtaf.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index e6ad9b74..0975d58a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -140,8 +140,6 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T - print(A_diagonal_blocks[i + 1, :, :]) - print(L_lower_diagonal_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], @@ -150,8 +148,7 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) - print(A_diagonal_blocks[i + 1, :, :]) - raise ValueError("TEST") + # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( From 588533817e416ff4d025b6eb76f1da40c18dbec5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:33:22 +0000 Subject: [PATCH 335/518] added test error --- src/serinv/algs/pobtaf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0975d58a..8e98793f 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -256,6 +256,7 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} 
- L_{i+1, i} @ L_{i+1, i}.conj().T + print(L_lower_diagonal_blocks[i, :, :]) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], @@ -264,7 +265,7 @@ def _pobtaf_permuted( trans_b='C', alpha=-1.0, beta=1.0 ) ) - + raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( From 1436b7f149d4d6c085a650ae2f45ba452099d3ea Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:33:58 +0000 Subject: [PATCH 336/518] fixed typo --- src/serinv/block_primitive/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py index 4d3895ec..89b14a73 100644 --- a/src/serinv/block_primitive/__init__.py +++ b/src/serinv/block_primitive/__init__.py @@ -1,6 +1,6 @@ from serinv.block_primitive.gemm import gemm from serinv.block_primitive.trsm import trsm -from serinv.block_primitive.trsm import syherk +from serinv.block_primitive.syherk import syherk __all__ = [ "gemm", From b8a5a25be5c6dd5a4b4420cef335aa18f049caa8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:37:04 +0000 Subject: [PATCH 337/518] moved test --- src/serinv/algs/pobtaf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 8e98793f..73583a25 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -126,6 +126,7 @@ def _pobtaf( ) ) + # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( @@ -148,7 +149,7 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) - + print(A_diagonal_blocks[i + 1, :, :]) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( @@ -168,6 +169,8 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) + print(A_arrow_tip_block[:, :]) + raise ValueError("TEST") if factorize_last_block: # L_{ndb, ndb} 
= chol(A_{ndb, ndb}) @@ -256,7 +259,6 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T - print(L_lower_diagonal_blocks[i, :, :]) A_diagonal_blocks[i + 1, :, :] = ( gemm( L_lower_diagonal_blocks[i, :, :], @@ -265,7 +267,6 @@ def _pobtaf_permuted( trans_b='C', alpha=-1.0, beta=1.0 ) ) - raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( From 2bf27a49f2615a2cb293e93c983fdcf7b3048dc9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:43:19 +0000 Subject: [PATCH 338/518] attempt for using syherk --- src/serinv/algs/pobtaf.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 73583a25..96a8edc0 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -7,7 +7,7 @@ _get_cholesky, ) -from serinv.block_primitive import trsm, gemm +from serinv.block_primitive import trsm, gemm, syherk def pobtaf( A_diagonal_blocks: ArrayLike, @@ -170,7 +170,6 @@ def _pobtaf( ) ) print(A_arrow_tip_block[:, :]) - raise ValueError("TEST") if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) @@ -188,14 +187,21 @@ def _pobtaf( ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} + # A_arrow_tip_block[:, :] = ( + # gemm( + # L_lower_arrow_blocks[-1, :, :], + # L_lower_arrow_blocks[-1, :, :], + # A_arrow_tip_block[:, :], + # trans_b='C', alpha=-1.0, beta=1.0 + # ) + # ) + A_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[-1, :, :], + syherk( L_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True ) - ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From e62d9ff84ad24234af68ca10fa31ccdcdd30c60e Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 
2025 09:43:44 +0000 Subject: [PATCH 339/518] missing parenthesis --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 96a8edc0..b7fb8a06 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -202,6 +202,7 @@ def _pobtaf( A_arrow_tip_block[:, :], alpha=-1.0, beta=1.0, lower=True ) + ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From bc52ddc6929e962562d7112d3c6b7e5f8635cebc Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:45:22 +0000 Subject: [PATCH 340/518] changed input for _syherk --- src/serinv/block_primitive/syherk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 388a6b5f..90085098 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -44,7 +44,7 @@ def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, uni overwrite_c = overwrite_c or _datacopied(c1, c) - x = _syherk(a1, c1, alpha, beta, trans, lower, unit_diagonal, overwrite_c, side) + x = _syherk(a1, c1, alpha, beta, trans, lower, overwrite_c) return x From cf3f7221e98eeaee9ddd86c9065463d1d3af6cfc Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 11 Jun 2025 09:46:08 +0000 Subject: [PATCH 341/518] removed iteration error --- src/serinv/block_primitive/syherk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 90085098..4478c1fb 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -55,9 +55,9 @@ def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) if np.iscomplexobj(a1): - syherk, = get_blas_funcs(('herk'), (a1, c1)) + syherk = get_blas_funcs(('herk'), (a1, 
c1)) else: - syherk, = get_blas_funcs(('syrk'), (a1, c1)) + syherk = get_blas_funcs(('syrk'), (a1, c1)) out = syherk(alpha, a1, beta, c1, trans, lower, overwrite_c) From 8fcdf5f25ac46112e5e94a312daf72bd1cd3dc25 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:04:30 +0000 Subject: [PATCH 342/518] changed implementation of syherk to not use cherk and zherk if not available on cupy --- src/serinv/block_primitive/syherk.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 4478c1fb..c63b8016 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,5 +1,7 @@ from serinv import _get_module_from_array +from serinv.block_primitive import gemm + import numpy as np from numpy.linalg import matmul @@ -114,9 +116,17 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals elif dtype == 'd': func = cublas.dsyrk elif dtype == 'F': - func = cublas.cherk + try: + func = cublas.cherk + except(AttributeError): + out = gemm(a, a, out, trans_b='C', alpha=alpha, beta=beta, lower=lower) + return out elif dtype == 'D': - func = cublas.zherk + try: + func = cublas.zherk + except(AttributeError): + out = gemm(a, a, out, trans_b='C', alpha=alpha, beta=beta, lower=lower) + return out else: raise TypeError('invalid dtype') From 4aedc59e32ff8eac26ca8464bc9fd2325457953b Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:05:11 +0000 Subject: [PATCH 343/518] fixed gemm call in syherk --- src/serinv/block_primitive/syherk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index c63b8016..c28ba247 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -119,13 +119,13 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals try: 
func = cublas.cherk except(AttributeError): - out = gemm(a, a, out, trans_b='C', alpha=alpha, beta=beta, lower=lower) + out = gemm(a, a, out, trans_b='C', alpha=alpha, beta=beta) return out elif dtype == 'D': try: func = cublas.zherk except(AttributeError): - out = gemm(a, a, out, trans_b='C', alpha=alpha, beta=beta, lower=lower) + out = gemm(a, a, out, trans_b='C', alpha=alpha, beta=beta) return out else: raise TypeError('invalid dtype') From fd73f46bfec38d3927cff172162d1362eef75bb5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:09:09 +0000 Subject: [PATCH 344/518] added debug print --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b7fb8a06..54c941e5 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -169,7 +169,6 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) - print(A_arrow_tip_block[:, :]) if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) @@ -203,6 +202,7 @@ def _pobtaf( alpha=-1.0, beta=1.0, lower=True ) ) + print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From 317a7cfa171bc73e77757584014135c7393d57aa Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:10:18 +0000 Subject: [PATCH 345/518] more debug print --- src/serinv/algs/pobtaf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 54c941e5..3f92119f 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -202,6 +202,10 @@ def _pobtaf( alpha=-1.0, beta=1.0, lower=True ) ) + print(syherk( + L_lower_arrow_blocks[-1, :, :], + alpha=-1.0, beta=1.0, lower=True + )) print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 4308160fbb54c7de320d3e9185ffce2f3fd90b6b Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:16:05 +0000 Subject: 
[PATCH 346/518] changed call to not include c1 --- src/serinv/algs/pobtaf.py | 4 ---- src/serinv/block_primitive/syherk.py | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3f92119f..54c941e5 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -202,10 +202,6 @@ def _pobtaf( alpha=-1.0, beta=1.0, lower=True ) ) - print(syherk( - L_lower_arrow_blocks[-1, :, :], - alpha=-1.0, beta=1.0, lower=True - )) print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index c28ba247..e5b72ac1 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -57,9 +57,9 @@ def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) if np.iscomplexobj(a1): - syherk = get_blas_funcs(('herk'), (a1, c1)) + syherk = get_blas_funcs(('herk'), (a1, a1)) else: - syherk = get_blas_funcs(('syrk'), (a1, c1)) + syherk = get_blas_funcs(('syrk'), (a1, a1)) out = syherk(alpha, a1, beta, c1, trans, lower, overwrite_c) From d9e1798bd6fbb77b1ad6702ea7546a5b27f2a0d5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:16:45 +0000 Subject: [PATCH 347/518] added print --- src/serinv/algs/pobtaf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 54c941e5..3f92119f 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -202,6 +202,10 @@ def _pobtaf( alpha=-1.0, beta=1.0, lower=True ) ) + print(syherk( + L_lower_arrow_blocks[-1, :, :], + alpha=-1.0, beta=1.0, lower=True + )) print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From d8e2557c534c503fff56c47163ac3737de057800 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:18:28 +0000 Subject: [PATCH 348/518] more print --- 
src/serinv/algs/pobtaf.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3f92119f..8b6084bc 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -202,11 +202,21 @@ def _pobtaf( alpha=-1.0, beta=1.0, lower=True ) ) + print("#") print(syherk( L_lower_arrow_blocks[-1, :, :], alpha=-1.0, beta=1.0, lower=True )) + print("#") print(A_arrow_tip_block[:, :]) + print("#") + print(gemm( + L_lower_arrow_blocks[-1, :, :], + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + )) + print("####") # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From 2e03d5fa1021c834d9856a48e6aa5e56ca71e191 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:21:59 +0000 Subject: [PATCH 349/518] added different prints --- src/serinv/algs/pobtaf.py | 16 ---------------- src/serinv/block_primitive/syherk.py | 2 ++ 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 8b6084bc..f4cc712b 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -149,7 +149,6 @@ def _pobtaf( trans_b='C', alpha=-1.0, beta=1.0 ) ) - print(A_diagonal_blocks[i + 1, :, :]) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( @@ -202,21 +201,6 @@ def _pobtaf( alpha=-1.0, beta=1.0, lower=True ) ) - print("#") - print(syherk( - L_lower_arrow_blocks[-1, :, :], - alpha=-1.0, beta=1.0, lower=True - )) - print("#") - print(A_arrow_tip_block[:, :]) - print("#") - print(gemm( - L_lower_arrow_blocks[-1, :, :], - L_lower_arrow_blocks[-1, :, :], - A_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - )) - print("####") # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) diff --git a/src/serinv/block_primitive/syherk.py 
b/src/serinv/block_primitive/syherk.py index e5b72ac1..7ce911d8 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -174,6 +174,7 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals func(handle, 1 - uplo, trans, n, k, alpha_ptr, a.data.ptr, lda, beta_ptr, out.data.ptr, ldo) + print("yes") finally: cublas.setPointerMode(handle, orig_mode) @@ -189,6 +190,7 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals func(handle, uplo, trans, n, k, alpha_ptr, a.data.ptr, lda, beta_ptr, out.data.ptr, ldo) + print("yes") finally: cublas.setPointerMode(handle, orig_mode) if not out._f_contiguous: From 469eb6750eba2c35a01ad16abff935f59bb0293f Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:23:55 +0000 Subject: [PATCH 350/518] more debug prints --- src/serinv/block_primitive/syherk.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 7ce911d8..6a3c9b1d 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -165,6 +165,13 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals lda, trans = _decide_ld_and_trans(a, trans) ldo, _ = _decide_ld_and_trans(out, trans) + + print(a) + print(out) + print(alpha) + print(beta) + print(lower) + if out._c_contiguous: if not a._c_contiguous: a = a.copy(order='C') @@ -174,7 +181,7 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals func(handle, 1 - uplo, trans, n, k, alpha_ptr, a.data.ptr, lda, beta_ptr, out.data.ptr, ldo) - print("yes") + print("yes1") finally: cublas.setPointerMode(handle, orig_mode) @@ -190,9 +197,10 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals func(handle, uplo, trans, n, k, alpha_ptr, a.data.ptr, lda, beta_ptr, out.data.ptr, ldo) - print("yes") + 
print("yes2") finally: cublas.setPointerMode(handle, orig_mode) if not out._f_contiguous: out[...] = c + print(out) return out \ No newline at end of file From 0ed638224391d1785376da21d53ceed6448047fb Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:26:22 +0000 Subject: [PATCH 351/518] new debug --- src/serinv/algs/pobtaf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f4cc712b..0e5e36d9 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,6 +193,11 @@ def _pobtaf( # trans_b='C', alpha=-1.0, beta=1.0 # ) # ) + a = syherk( + L_lower_arrow_blocks[-1, :, :], + alpha=-1.0, beta=0.0, lower=True + ) + print(a) A_arrow_tip_block[:, :] = ( syherk( From 11c257ecd636637e4693443d24e670fbd7793802 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:32:01 +0000 Subject: [PATCH 352/518] further print --- src/serinv/algs/pobtaf.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0e5e36d9..3cbd24fb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,6 +193,12 @@ def _pobtaf( # trans_b='C', alpha=-1.0, beta=1.0 # ) # ) + b = gemm( + L_lower_arrow_blocks[-1, :, :], + L_lower_arrow_blocks[-1, :, :], + trans_b='C', alpha=-1.0, beta=0.0 + ) + print(b) a = syherk( L_lower_arrow_blocks[-1, :, :], alpha=-1.0, beta=0.0, lower=True From 83b54e1b4ffeab77d7496a96feeb348815a823f1 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:34:23 +0000 Subject: [PATCH 353/518] removed some prints --- src/serinv/block_primitive/syherk.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 6a3c9b1d..2ecde792 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -166,11 +166,11 @@ def matmul_syherk_device(a, trans='N', out=None, 
alpha=1.0, beta=0.0, lower=Fals lda, trans = _decide_ld_and_trans(a, trans) ldo, _ = _decide_ld_and_trans(out, trans) - print(a) - print(out) - print(alpha) - print(beta) - print(lower) + #print(a) + #print(out) + #print(alpha) + #print(beta) + #print(lower) if out._c_contiguous: if not a._c_contiguous: @@ -202,5 +202,5 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals cublas.setPointerMode(handle, orig_mode) if not out._f_contiguous: out[...] = c - print(out) + #print(out) return out \ No newline at end of file From af61f321f662d20c6352ac664a337c09f854ad08 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:36:05 +0000 Subject: [PATCH 354/518] more print --- src/serinv/algs/pobtaf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3cbd24fb..2907d3ae 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,14 +193,18 @@ def _pobtaf( # trans_b='C', alpha=-1.0, beta=1.0 # ) # ) + c = A_arrow_tip_block[:, :] b = gemm( L_lower_arrow_blocks[-1, :, :], L_lower_arrow_blocks[-1, :, :], + c, trans_b='C', alpha=-1.0, beta=0.0 ) print(b) + c = A_arrow_tip_block[:, :] a = syherk( L_lower_arrow_blocks[-1, :, :], + c, alpha=-1.0, beta=0.0, lower=True ) print(a) From 2a8029315d63bb676159ff82122f63a3afc822a4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:39:41 +0000 Subject: [PATCH 355/518] changed print --- src/serinv/algs/pobtaf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 2907d3ae..4293a1ee 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,6 +193,7 @@ def _pobtaf( # trans_b='C', alpha=-1.0, beta=1.0 # ) # ) + c = A_arrow_tip_block[:, :] b = gemm( L_lower_arrow_blocks[-1, :, :], @@ -205,7 +206,7 @@ def _pobtaf( a = syherk( L_lower_arrow_blocks[-1, :, :], c, - alpha=-1.0, beta=0.0, lower=True + alpha=-1.0, beta=1.0, 
lower=True ) print(a) From e0a54b4d0d3cb78661459c67e3cc3e3516bb1286 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:40:59 +0000 Subject: [PATCH 356/518] changed prints again --- src/serinv/block_primitive/syherk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 2ecde792..bebb796f 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -166,8 +166,8 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals lda, trans = _decide_ld_and_trans(a, trans) ldo, _ = _decide_ld_and_trans(out, trans) - #print(a) - #print(out) + print(a) + print(out) #print(alpha) #print(beta) #print(lower) From 154d8c7ae9599b3031c5e65ba4d21261436d5777 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:43:04 +0000 Subject: [PATCH 357/518] reverted some prints --- src/serinv/algs/pobtaf.py | 16 ---------------- src/serinv/block_primitive/syherk.py | 10 +++++----- 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 4293a1ee..f4cc712b 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -194,22 +194,6 @@ def _pobtaf( # ) # ) - c = A_arrow_tip_block[:, :] - b = gemm( - L_lower_arrow_blocks[-1, :, :], - L_lower_arrow_blocks[-1, :, :], - c, - trans_b='C', alpha=-1.0, beta=0.0 - ) - print(b) - c = A_arrow_tip_block[:, :] - a = syherk( - L_lower_arrow_blocks[-1, :, :], - c, - alpha=-1.0, beta=1.0, lower=True - ) - print(a) - A_arrow_tip_block[:, :] = ( syherk( L_lower_arrow_blocks[-1, :, :], diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index bebb796f..bba6295c 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -30,7 +30,7 @@ def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower = False): else: ModuleNotFoundError("Unknown 
Module") -def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, unit_diagonal=False, +def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, overwrite_c=False, check_finite=True, side=0): """Computes out = alpha * op(a) @ op(a)^T + beta * b @@ -168,9 +168,9 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals print(a) print(out) - #print(alpha) - #print(beta) - #print(lower) + print(alpha) + print(beta) + print(lower) if out._c_contiguous: if not a._c_contiguous: @@ -202,5 +202,5 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals cublas.setPointerMode(handle, orig_mode) if not out._f_contiguous: out[...] = c - #print(out) + print(out) return out \ No newline at end of file From 2759663172e1b0a2942da23a99a1c24ec21b7492 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:53:10 +0000 Subject: [PATCH 358/518] forcing out=none for noncomplex device syrk --- src/serinv/block_primitive/syherk.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index bba6295c..d7323207 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -31,7 +31,7 @@ def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower = False): ModuleNotFoundError("Unknown Module") def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, - overwrite_c=False, check_finite=True, side=0): + overwrite_c=False, check_finite=True,): """Computes out = alpha * op(a) @ op(a)^T + beta * b op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', @@ -129,6 +129,9 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals return out else: raise TypeError('invalid dtype') + + # Testing remove later + out=None trans = _trans_to_cublas_op(trans) if trans == cublas.CUBLAS_OP_N: From 5e27f76271d7a29828fe8d7233b516fdd63f1168 Mon 
Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 08:59:27 +0000 Subject: [PATCH 359/518] changed syher behavior in pobtas for testing --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f4cc712b..21f2aed0 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -194,11 +194,11 @@ def _pobtaf( # ) # ) - A_arrow_tip_block[:, :] = ( + A_arrow_tip_block[:, :] -= ( syherk( L_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True + alpha=-1.0, beta=0.0, lower=True ) ) From f05a71411f8a4b484e70f4fbd81f4cd6416f832a Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:00:09 +0000 Subject: [PATCH 360/518] further change --- src/serinv/algs/pobtaf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 21f2aed0..3ecd360a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -197,7 +197,6 @@ def _pobtaf( A_arrow_tip_block[:, :] -= ( syherk( L_lower_arrow_blocks[-1, :, :], - A_arrow_tip_block[:, :], alpha=-1.0, beta=0.0, lower=True ) ) From 22380267392d91dcd98c9da76eea0657e7c8b0da Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:02:54 +0000 Subject: [PATCH 361/518] reverted pobtaf changes --- src/serinv/algs/pobtaf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3ecd360a..f4cc712b 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -194,10 +194,11 @@ def _pobtaf( # ) # ) - A_arrow_tip_block[:, :] -= ( + A_arrow_tip_block[:, :] = ( syherk( L_lower_arrow_blocks[-1, :, :], - alpha=-1.0, beta=0.0, lower=True + A_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True ) ) From efcfea0ee96149a6f0787659539d0b92a28b67c7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:06:06 +0000 Subject: 
[PATCH 362/518] changes to test syherk --- src/serinv/algs/pobtaf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f4cc712b..f017b48c 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,13 +193,13 @@ def _pobtaf( # trans_b='C', alpha=-1.0, beta=1.0 # ) # ) - - A_arrow_tip_block[:, :] = ( - syherk( + c = syherk( L_lower_arrow_blocks[-1, :, :], - A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True + alpha=-1.0, beta=0.0, lower=True ) + + A_arrow_tip_block[:, :] -= ( + c ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 01730d5a28d1dbd4b17f0d177500ac8368db12bd Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:06:53 +0000 Subject: [PATCH 363/518] more debug changes --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f017b48c..fc3689f3 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -198,8 +198,8 @@ def _pobtaf( alpha=-1.0, beta=0.0, lower=True ) - A_arrow_tip_block[:, :] -= ( - c + A_arrow_tip_block[:, :] = ( + A_arrow_tip_block[:, :]-c ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 7a5abe292d72401113054048178448bd339a21d5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:09:10 +0000 Subject: [PATCH 364/518] removed forced out --- src/serinv/algs/pobtaf.py | 10 +++++----- src/serinv/block_primitive/syherk.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index fc3689f3..f4cc712b 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,13 +193,13 @@ def _pobtaf( # trans_b='C', alpha=-1.0, beta=1.0 # ) # ) - c = syherk( - L_lower_arrow_blocks[-1, :, :], - alpha=-1.0, beta=0.0, lower=True - ) A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :]-c + syherk( + L_lower_arrow_blocks[-1, :, 
:], + A_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index d7323207..7fe1d557 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -131,7 +131,7 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals raise TypeError('invalid dtype') # Testing remove later - out=None + #out=None trans = _trans_to_cublas_op(trans) if trans == cublas.CUBLAS_OP_N: From c3d9355350e2a489f65288669733f06af65cc0bc Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:13:32 +0000 Subject: [PATCH 365/518] reverted to gemm for testing --- src/serinv/algs/pobtaf.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f4cc712b..08219137 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -185,22 +185,24 @@ def _pobtaf( ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} - # A_arrow_tip_block[:, :] = ( - # gemm( - # L_lower_arrow_blocks[-1, :, :], - # L_lower_arrow_blocks[-1, :, :], - # A_arrow_tip_block[:, :], - # trans_b='C', alpha=-1.0, beta=1.0 - # ) - # ) - A_arrow_tip_block[:, :] = ( - syherk( + gemm( + L_lower_arrow_blocks[-1, :, :], L_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True + trans_b='C', alpha=-1.0, beta=1.0 ) ) + print(A_arrow_tip_block[:, :]) + raise ValueError("TEST") + + #A_arrow_tip_block[:, :] = ( + # syherk( + # L_lower_arrow_blocks[-1, :, :], + # A_arrow_tip_block[:, :], + # alpha=-1.0, beta=1.0, lower=True + # ) + #) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From 3c1715adcd1cdd4605e6a8256fb26c772196baf9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:31:45 +0000 Subject: 
[PATCH 366/518] removed raised error --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 08219137..166e603c 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -194,7 +194,7 @@ def _pobtaf( ) ) print(A_arrow_tip_block[:, :]) - raise ValueError("TEST") + #raise ValueError("TEST") #A_arrow_tip_block[:, :] = ( # syherk( From 49ec7462ca5581f2daebc8492e4e53ef1f864864 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:47:05 +0000 Subject: [PATCH 367/518] changed back to syherk --- src/serinv/algs/pobtaf.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 166e603c..3e5ed358 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -185,24 +185,24 @@ def _pobtaf( ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} - A_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[-1, :, :], - L_lower_arrow_blocks[-1, :, :], - A_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) - ) - print(A_arrow_tip_block[:, :]) - #raise ValueError("TEST") - #A_arrow_tip_block[:, :] = ( - # syherk( + # gemm( + # L_lower_arrow_blocks[-1, :, :], # L_lower_arrow_blocks[-1, :, :], # A_arrow_tip_block[:, :], - # alpha=-1.0, beta=1.0, lower=True + # trans_b='C', alpha=-1.0, beta=1.0 # ) #) + print(A_arrow_tip_block[:, :]) + #raise ValueError("TEST") + + A_arrow_tip_block[:, :] = ( + syherk( + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True + ) + ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From c9fdcd09fc482b37d30316b7dbab00a61e245f9e Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:52:22 +0000 Subject: [PATCH 368/518] more prints --- src/serinv/algs/pobtaf.py | 4 ++++ 1 file changed, 
4 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 3e5ed358..79cefcc4 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -195,6 +195,9 @@ def _pobtaf( #) print(A_arrow_tip_block[:, :]) #raise ValueError("TEST") + + print(L_lower_arrow_blocks[-1, :, :]) + A_arrow_tip_block[:, :] = ( syherk( @@ -203,6 +206,7 @@ def _pobtaf( alpha=-1.0, beta=1.0, lower=True ) ) + print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) From 7958743af3056c5589853e83f2c7b7d272eeffe8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 09:53:38 +0000 Subject: [PATCH 369/518] reverted to gemm --- src/serinv/algs/pobtaf.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 79cefcc4..9d1745b7 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -185,16 +185,16 @@ def _pobtaf( ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} - #A_arrow_tip_block[:, :] = ( - # gemm( - # L_lower_arrow_blocks[-1, :, :], - # L_lower_arrow_blocks[-1, :, :], - # A_arrow_tip_block[:, :], - # trans_b='C', alpha=-1.0, beta=1.0 - # ) - #) + A_arrow_tip_block[:, :] = ( + gemm( + L_lower_arrow_blocks[-1, :, :], + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) + ) print(A_arrow_tip_block[:, :]) - #raise ValueError("TEST") + raise ValueError("TEST") print(L_lower_arrow_blocks[-1, :, :]) From 6d4d8173ffc625370ea9dbaf74a753fb6980bdf1 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:09:31 +0000 Subject: [PATCH 370/518] trying gemm again --- src/serinv/algs/pobtaf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 9d1745b7..a62a64d0 100644 --- a/src/serinv/algs/pobtaf.py +++ 
b/src/serinv/algs/pobtaf.py @@ -194,18 +194,18 @@ def _pobtaf( ) ) print(A_arrow_tip_block[:, :]) - raise ValueError("TEST") + #raise ValueError("TEST") print(L_lower_arrow_blocks[-1, :, :]) - A_arrow_tip_block[:, :] = ( - syherk( - L_lower_arrow_blocks[-1, :, :], - A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True - ) - ) + #A_arrow_tip_block[:, :] = ( + # syherk( + # L_lower_arrow_blocks[-1, :, :], + # A_arrow_tip_block[:, :], + # alpha=-1.0, beta=1.0, lower=True + # ) + #) print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 705de55f29d115c3696f0302fb791127d06d33ab Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:10:36 +0000 Subject: [PATCH 371/518] gemm works, returning to syrk --- src/serinv/algs/pobtaf.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index a62a64d0..79cefcc4 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -185,27 +185,27 @@ def _pobtaf( ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} - A_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[-1, :, :], - L_lower_arrow_blocks[-1, :, :], - A_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) - ) + #A_arrow_tip_block[:, :] = ( + # gemm( + # L_lower_arrow_blocks[-1, :, :], + # L_lower_arrow_blocks[-1, :, :], + # A_arrow_tip_block[:, :], + # trans_b='C', alpha=-1.0, beta=1.0 + # ) + #) print(A_arrow_tip_block[:, :]) #raise ValueError("TEST") print(L_lower_arrow_blocks[-1, :, :]) - #A_arrow_tip_block[:, :] = ( - # syherk( - # L_lower_arrow_blocks[-1, :, :], - # A_arrow_tip_block[:, :], - # alpha=-1.0, beta=1.0, lower=True - # ) - #) + A_arrow_tip_block[:, :] = ( + syherk( + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True + ) + ) print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 
69811e2870c1e44f443a9ec9a67db24154b1f9ac Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:21:18 +0000 Subject: [PATCH 372/518] added print --- src/serinv/algs/pobtaf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 79cefcc4..b5d62e27 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -196,7 +196,9 @@ def _pobtaf( print(A_arrow_tip_block[:, :]) #raise ValueError("TEST") - print(L_lower_arrow_blocks[-1, :, :]) + print(L_lower_arrow_blocks[-1, :, :]-gemm(L_lower_arrow_blocks[-1, :, :], + L_lower_arrow_blocks[-1, :, :], + trans_b='C', alpha=-1.0)) A_arrow_tip_block[:, :] = ( From c483ccce861473ec77c014e3b0ae3cb7561f8843 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:22:02 +0000 Subject: [PATCH 373/518] changed print --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b5d62e27..645ceaaa 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -196,7 +196,7 @@ def _pobtaf( print(A_arrow_tip_block[:, :]) #raise ValueError("TEST") - print(L_lower_arrow_blocks[-1, :, :]-gemm(L_lower_arrow_blocks[-1, :, :], + print(A_arrow_tip_block[:, :]-gemm(L_lower_arrow_blocks[-1, :, :], L_lower_arrow_blocks[-1, :, :], trans_b='C', alpha=-1.0)) From a32f0b1919da3edbe9d0b6d09981472f853cbd23 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:22:53 +0000 Subject: [PATCH 374/518] fixed print --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 645ceaaa..57534085 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -198,7 +198,7 @@ def _pobtaf( print(A_arrow_tip_block[:, :]-gemm(L_lower_arrow_blocks[-1, :, :], L_lower_arrow_blocks[-1, :, :], - trans_b='C', alpha=-1.0)) + trans_b='C', alpha=1.0)) 
A_arrow_tip_block[:, :] = ( From 1ae99f95bf3f47fd84ca6444df765b8e88b053ac Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:24:08 +0000 Subject: [PATCH 375/518] sanity check --- src/serinv/algs/pobtaf.py | 1 - src/serinv/block_primitive/syherk.py | 12 ++++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 57534085..f7974e16 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,7 +193,6 @@ def _pobtaf( # trans_b='C', alpha=-1.0, beta=1.0 # ) #) - print(A_arrow_tip_block[:, :]) #raise ValueError("TEST") print(A_arrow_tip_block[:, :]-gemm(L_lower_arrow_blocks[-1, :, :], diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 7fe1d557..22b00268 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -169,11 +169,11 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals lda, trans = _decide_ld_and_trans(a, trans) ldo, _ = _decide_ld_and_trans(out, trans) - print(a) - print(out) - print(alpha) - print(beta) - print(lower) + #print(a) + #print(out) + #print(alpha) + #print(beta) + #print(lower) if out._c_contiguous: if not a._c_contiguous: @@ -205,5 +205,5 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals cublas.setPointerMode(handle, orig_mode) if not out._f_contiguous: out[...] 
= c - print(out) + #print(out) return out \ No newline at end of file From b1b583850f8bf98d1076399bd39cc361a8241acf Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:26:20 +0000 Subject: [PATCH 376/518] more sanity --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f7974e16..980bc6c3 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -199,6 +199,7 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :], trans_b='C', alpha=1.0)) + print(A_arrow_tip_block[:, :]-syherk(L_lower_arrow_blocks[-1, :, :],lower=True)) A_arrow_tip_block[:, :] = ( syherk( From 15deadfe535aedd24f2b028b0008b7bd082b9315 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:27:30 +0000 Subject: [PATCH 377/518] sanity 1 --- src/serinv/algs/pobtaf.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 980bc6c3..7ab8239b 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -199,15 +199,16 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :], trans_b='C', alpha=1.0)) - print(A_arrow_tip_block[:, :]-syherk(L_lower_arrow_blocks[-1, :, :],lower=True)) + A_arrow_tip_block[:, :]-=syherk(L_lower_arrow_blocks[-1, :, :],lower=True) + + #A_arrow_tip_block[:, :] = ( + # syherk( + # L_lower_arrow_blocks[-1, :, :], + # A_arrow_tip_block[:, :], + # alpha=-1.0, beta=1.0, lower=True + # ) + #) - A_arrow_tip_block[:, :] = ( - syherk( - L_lower_arrow_blocks[-1, :, :], - A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True - ) - ) print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 8ff51d0b31f10d89550898f13eec1c27b71b057a Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:28:40 +0000 Subject: [PATCH 378/518] sanity 2 --- src/serinv/algs/pobtaf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7ab8239b..b89aa4b1 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -195,11 +195,11 @@ def _pobtaf( #) #raise ValueError("TEST") - print(A_arrow_tip_block[:, :]-gemm(L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :]-=gemm(L_lower_arrow_blocks[-1, :, :], L_lower_arrow_blocks[-1, :, :], - trans_b='C', alpha=1.0)) + trans_b='C', alpha=1.0) - A_arrow_tip_block[:, :]-=syherk(L_lower_arrow_blocks[-1, :, :],lower=True) + print(A_arrow_tip_block[:, :]-syherk(L_lower_arrow_blocks[-1, :, :],lower=True)) #A_arrow_tip_block[:, :] = ( # syherk( @@ -208,7 +208,7 @@ def _pobtaf( # alpha=-1.0, beta=1.0, lower=True # ) #) - + print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From ab106dfe49ebf08697bef412211cb5b205e6f782 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:30:01 +0000 Subject: [PATCH 379/518] sanity 3 --- src/serinv/algs/pobtaf.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b89aa4b1..5ad1f714 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -195,11 +195,11 @@ def _pobtaf( #) #raise ValueError("TEST") - A_arrow_tip_block[:, :]-=gemm(L_lower_arrow_blocks[-1, :, :], + print(A_arrow_tip_block[:, :]-gemm(L_lower_arrow_blocks[-1, :, :], L_lower_arrow_blocks[-1, :, :], - trans_b='C', alpha=1.0) + trans_b='C', alpha=1.0)) - print(A_arrow_tip_block[:, :]-syherk(L_lower_arrow_blocks[-1, :, :],lower=True)) + A_arrow_tip_block[:, :]-=syherk(L_lower_arrow_blocks[-1, :, :],lower=True)) #A_arrow_tip_block[:, :] = ( # syherk( @@ -209,11 +209,12 @@ def _pobtaf( # ) #) - print(A_arrow_tip_block[:, :]) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) + print(L_arrow_tip_block[:, :]) + def _pobtaf_permuted( A_diagonal_blocks: ArrayLike, From 5435a223344b3e9ded64865e4defce9a931b07a0 Mon 
Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:30:34 +0000 Subject: [PATCH 380/518] fixed parenthesis --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 5ad1f714..cab0e793 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -199,7 +199,7 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :], trans_b='C', alpha=1.0)) - A_arrow_tip_block[:, :]-=syherk(L_lower_arrow_blocks[-1, :, :],lower=True)) + A_arrow_tip_block[:, :]-=syherk(L_lower_arrow_blocks[-1, :, :],lower=True) #A_arrow_tip_block[:, :] = ( # syherk( From 870c1f95c4cf4ded5155c89ecd503171118d0ee2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:32:04 +0000 Subject: [PATCH 381/518] changed lower to upper --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index cab0e793..f930d7df 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -199,7 +199,7 @@ def _pobtaf( L_lower_arrow_blocks[-1, :, :], trans_b='C', alpha=1.0)) - A_arrow_tip_block[:, :]-=syherk(L_lower_arrow_blocks[-1, :, :],lower=True) + A_arrow_tip_block[:, :]-=syherk(L_lower_arrow_blocks[-1, :, :],lower=False) #A_arrow_tip_block[:, :] = ( # syherk( From f79c21ac3b08f773df97070ff2b1a8be95603e08 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:35:53 +0000 Subject: [PATCH 382/518] swap lower on device to match cholesky. 
hopefully temporary --- src/serinv/algs/pobtaf.py | 20 +++++++------------- src/serinv/block_primitive/syherk.py | 1 + 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f930d7df..f12f5a66 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -194,20 +194,14 @@ def _pobtaf( # ) #) #raise ValueError("TEST") - - print(A_arrow_tip_block[:, :]-gemm(L_lower_arrow_blocks[-1, :, :], - L_lower_arrow_blocks[-1, :, :], - trans_b='C', alpha=1.0)) - A_arrow_tip_block[:, :]-=syherk(L_lower_arrow_blocks[-1, :, :],lower=False) - - #A_arrow_tip_block[:, :] = ( - # syherk( - # L_lower_arrow_blocks[-1, :, :], - # A_arrow_tip_block[:, :], - # alpha=-1.0, beta=1.0, lower=True - # ) - #) + A_arrow_tip_block[:, :] = ( + syherk( + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True + ) + ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 22b00268..2e28aba9 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -26,6 +26,7 @@ def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower = False): if xp == np: return matmul_syherk_host(a, c, alpha, beta, trans, lower) elif xp == cp: + lower = not lower return matmul_syherk_device(a, trans, c, alpha, beta, lower) else: ModuleNotFoundError("Unknown Module") From 55548d0e80a80d0cd97860a1f430d5c76973d58d Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 10:54:09 +0000 Subject: [PATCH 383/518] random commit --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f12f5a66..7a315c64 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -206,7 +206,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = 
cholesky(A_arrow_tip_block[:, :]) - + print("LOL") print(L_arrow_tip_block[:, :]) From 030124f90ef386d4dd5d3c3305b1321fa7d29e95 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 11:01:01 +0000 Subject: [PATCH 384/518] updated 2. syherk --- src/serinv/algs/pobtaf.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7a315c64..d3cd6063 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -161,11 +161,10 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], + syherk( L_lower_arrow_blocks[i, :, :], A_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True ) ) @@ -185,16 +184,6 @@ def _pobtaf( ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} - #A_arrow_tip_block[:, :] = ( - # gemm( - # L_lower_arrow_blocks[-1, :, :], - # L_lower_arrow_blocks[-1, :, :], - # A_arrow_tip_block[:, :], - # trans_b='C', alpha=-1.0, beta=1.0 - # ) - #) - #raise ValueError("TEST") - A_arrow_tip_block[:, :] = ( syherk( L_lower_arrow_blocks[-1, :, :], @@ -206,8 +195,6 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) - print("LOL") - print(L_arrow_tip_block[:, :]) def _pobtaf_permuted( From 7e66166216d92e3cb644b450344e4604638a3ba7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 11:01:44 +0000 Subject: [PATCH 385/518] third syherk --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d3cd6063..68809de1 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -142,11 +142,10 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T 
A_diagonal_blocks[i + 1, :, :] = ( - gemm( - L_lower_diagonal_blocks[i, :, :], + syherk( L_lower_diagonal_blocks[i, :, :], A_diagonal_blocks[i + 1, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From 8cbda4bb101150e7e2f4ad22bcc82088163ab618 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 11:09:41 +0000 Subject: [PATCH 386/518] removed messages --- src/serinv/block_primitive/syherk.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 2e28aba9..21248484 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -170,12 +170,6 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals lda, trans = _decide_ld_and_trans(a, trans) ldo, _ = _decide_ld_and_trans(out, trans) - #print(a) - #print(out) - #print(alpha) - #print(beta) - #print(lower) - if out._c_contiguous: if not a._c_contiguous: a = a.copy(order='C') @@ -185,7 +179,6 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals func(handle, 1 - uplo, trans, n, k, alpha_ptr, a.data.ptr, lda, beta_ptr, out.data.ptr, ldo) - print("yes1") finally: cublas.setPointerMode(handle, orig_mode) @@ -201,10 +194,8 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals func(handle, uplo, trans, n, k, alpha_ptr, a.data.ptr, lda, beta_ptr, out.data.ptr, ldo) - print("yes2") finally: cublas.setPointerMode(handle, orig_mode) if not out._f_contiguous: out[...] 
= c - #print(out) return out \ No newline at end of file From cd5dd3535e1caf6d7757f87f2acff3bede2436c0 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 11:56:43 +0000 Subject: [PATCH 387/518] exposed lower param on cholesky for cupy --- src/serinv/__init__.py | 2 +- src/serinv/algs/pobtaf.py | 2 +- src/serinv/block_primitive/syherk.py | 1 - src/serinv/cupyfix/__init__.py | 4 ++-- src/serinv/cupyfix/cholesky_lowerfill.py | 9 +++++++-- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/serinv/__init__.py b/src/serinv/__init__.py index bad9c860..a8003816 100644 --- a/src/serinv/__init__.py +++ b/src/serinv/__init__.py @@ -25,7 +25,7 @@ # In the case of CuPy, we want to use the lowerfill version # tweaked in serinv. (More performances) - from serinv.cupyfix.cholesky_lowerfill import cholesky_lowerfill as cu_cholesky + from serinv.cupyfix.cholesky_lowerfill import cholesky as cu_cholesky # Check if cupy is actually working. This could still raise # a cudaErrorInsufficientDriver error or something. 
diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 68809de1..431b6057 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,7 +193,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) + L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) def _pobtaf_permuted( diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 21248484..e89bafba 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -26,7 +26,6 @@ def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower = False): if xp == np: return matmul_syherk_host(a, c, alpha, beta, trans, lower) elif xp == cp: - lower = not lower return matmul_syherk_device(a, trans, c, alpha, beta, lower) else: ModuleNotFoundError("Unknown Module") diff --git a/src/serinv/cupyfix/__init__.py b/src/serinv/cupyfix/__init__.py index 60f6f426..37015a0c 100644 --- a/src/serinv/cupyfix/__init__.py +++ b/src/serinv/cupyfix/__init__.py @@ -1,7 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. -from serinv.cupyfix.cholesky_lowerfill import cholesky_lowerfill +from serinv.cupyfix.cholesky_lowerfill import cholesky __all__ = [ - "cholesky_lowerfill", + "cholesky", ] diff --git a/src/serinv/cupyfix/cholesky_lowerfill.py b/src/serinv/cupyfix/cholesky_lowerfill.py index e4778c35..a5d053e0 100644 --- a/src/serinv/cupyfix/cholesky_lowerfill.py +++ b/src/serinv/cupyfix/cholesky_lowerfill.py @@ -7,7 +7,7 @@ from cupy.linalg import _util -def cholesky_lowerfill(a: cupy.ndarray) -> cupy.ndarray: +def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: """Cholesky decomposition. 
Decompose a given two-dimensional square matrix into ``L * L.H``, @@ -49,6 +49,11 @@ def cholesky_lowerfill(a: cupy.ndarray) -> cupy.ndarray: handle = device.get_cusolver_handle() dev_info = cupy.empty(1, dtype=numpy.int32) + if lower: + lower = cublas.CUBLAS_FILL_MODE_LOWER + else: + lower = cublas.CUBLAS_FILL_MODE_UPPER + if dtype == "f": potrf = cusolver.spotrf potrf_bufferSize = cusolver.spotrf_bufferSize @@ -68,7 +73,7 @@ def cholesky_lowerfill(a: cupy.ndarray) -> cupy.ndarray: workspace = cupy.empty(buffersize, dtype=dtype) potrf( handle, - cublas.CUBLAS_FILL_MODE_LOWER, + lower, n, x.data.ptr, n, From 9395fc8dc300c1c9fef94fa848dc4a048f0f43d6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 11:57:40 +0000 Subject: [PATCH 388/518] switched lower --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 431b6057..d85832e6 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,7 +193,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) + L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=False) def _pobtaf_permuted( From 9d64ee9d74434f49c722babd8933f789f34d42d9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 12:00:06 +0000 Subject: [PATCH 389/518] removed lower param in pobtaf from cholesky --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d85832e6..68809de1 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,7 +193,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=False) + L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) def _pobtaf_permuted( From 5fc8601098ea87e9db648157d384f790aab6bd0b Mon Sep 17 
00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 12:03:27 +0000 Subject: [PATCH 390/518] added lower param again --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 68809de1..d85832e6 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -193,7 +193,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) + L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=False) def _pobtaf_permuted( From d73e2217d93f9411684acfa196010c968add15b1 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 12:07:48 +0000 Subject: [PATCH 391/518] transposing arrow tip block --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d85832e6..c38972c2 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -187,13 +187,13 @@ def _pobtaf( syherk( L_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True + alpha=-1.0, beta=1.0, lower=True, trans='C' ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=False) + L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) def _pobtaf_permuted( From 862cd3bc1bf777e0c008f080d59e04bfac0ed638 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 12:08:37 +0000 Subject: [PATCH 392/518] trans T --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index c38972c2..ed0753d4 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -187,7 +187,7 @@ def _pobtaf( syherk( L_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True, trans='C' + alpha=-1.0, beta=1.0, 
lower=True, trans='N' ) ) From 41899f710a0511b3711afd475fcdf90d3d9311f5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 12:14:42 +0000 Subject: [PATCH 393/518] added one new cu_chol --- src/serinv/algs/pobtaf.py | 2 +- src/serinv/block_primitive/syherk.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index ed0753d4..2d37e11d 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -187,7 +187,7 @@ def _pobtaf( syherk( L_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True, trans='N' + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index e89bafba..d709f4b4 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -16,7 +16,7 @@ except (ImportError, ImportWarning, ModuleNotFoundError): pass -def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower = False): +def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower=False, cu_chol=False): """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept @@ -26,7 +26,7 @@ def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower = False): if xp == np: return matmul_syherk_host(a, c, alpha, beta, trans, lower) elif xp == cp: - return matmul_syherk_device(a, trans, c, alpha, beta, lower) + return matmul_syherk_device(a, trans, c, alpha, beta, lower, cu_chol) else: ModuleNotFoundError("Unknown Module") @@ -101,7 +101,7 @@ def _get_scalar_ptr(a, dtype): return a, a_ptr # Util functions for cupy gemm end -def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=False): +def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=False, cu_chol=False): 
"""Computes out := alpha*op1(a)*op2(a) + beta*out op1(a) = a if trans is 'N', op2(a) = a.T if transa is 'N' @@ -130,8 +130,9 @@ def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=Fals else: raise TypeError('invalid dtype') - # Testing remove later - #out=None + # If this is run in combination with cholesky, it will be necessary to flip lower + if cu_chol: + lower = not lower trans = _trans_to_cublas_op(trans) if trans == cublas.CUBLAS_OP_N: From c5923433ff3c47c2a18f60b88148496b1a1eea3f Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 12:16:48 +0000 Subject: [PATCH 394/518] added cu_chol to all syherk --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 2d37e11d..d1d0b813 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -145,7 +145,7 @@ def _pobtaf( syherk( L_lower_diagonal_blocks[i, :, :], A_diagonal_blocks[i + 1, :, :], - alpha=-1.0, beta=1.0, lower=True + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T @@ -163,7 +163,7 @@ def _pobtaf( syherk( L_lower_arrow_blocks[i, :, :], A_arrow_tip_block[:, :], - alpha=-1.0, beta=1.0, lower=True + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From 9fa975dcfb38d21b2f747e13168f1d32fda38af9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 12:17:42 +0000 Subject: [PATCH 395/518] added lower to second cholesky --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index d1d0b813..b84ee25d 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -169,7 +169,7 @@ def _pobtaf( if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) - L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :]) + L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, 
:, :], lower=True) # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} L_lower_arrow_blocks[-1, :, :] = ( From 87f23ee1beafddb3b4ccb0ea863a4ee9eb23f918 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 12:19:16 +0000 Subject: [PATCH 396/518] further cholesky lower --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b84ee25d..085166b4 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -115,7 +115,7 @@ def _pobtaf( # Forward block-Cholesky for i in range(0, n_diag_blocks - 1): # L_{i, i} = chol(A_{i, i}) - L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :]) + L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( From 1f90c2ef6766469c4aca10b462b67f9944623176 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 13:40:21 +0000 Subject: [PATCH 397/518] first permuted syherk --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 085166b4..7be5e982 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -255,11 +255,10 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - gemm( - L_lower_diagonal_blocks[i, :, :], + syherk( L_lower_diagonal_blocks[i, :, :], A_diagonal_blocks[i + 1, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From 1e9eefee48afb90640c168c35b955fc9c3a5957d Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 13:41:05 +0000 Subject: [PATCH 398/518] 2 syherk permuted --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 
deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7be5e982..4266da81 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -274,11 +274,10 @@ def _pobtaf_permuted( # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - gemm( - L_lower_arrow_blocks[i, :, :], + syherk( L_lower_arrow_blocks[i, :, :], L_arrow_tip_block[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From 9219872b6698917b0085537931067e989ea2026a Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 13:41:39 +0000 Subject: [PATCH 399/518] syherk 3 oermuted --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 4266da81..1a633e04 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -284,11 +284,10 @@ def _pobtaf_permuted( # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - gemm( - buffer[i, :, :], + syherk( buffer[i, :, :], A_diagonal_blocks[0, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From 9e802a56aa12924ecead0070c9989abfb8d63ca0 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 13:42:57 +0000 Subject: [PATCH 400/518] syherk streaming 1 --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 1a633e04..af62e564 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -475,11 +475,10 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = 
( - gemm( - L_lower_diagonal_blocks_d[i % 2, :, :], + syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From 9e09f8b8b0e18bcc4eb6dc39d3d5705de5355f2c Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 13:43:37 +0000 Subject: [PATCH 401/518] syherk streaming 2 --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index af62e564..58aee287 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -495,11 +495,10 @@ def _pobtaf_streaming( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], + syherk( L_lower_arrow_blocks_d[i % 2, :, :], A_arrow_tip_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) From 04bb683444164a34610b5c523205f22569091d7f Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 13:44:35 +0000 Subject: [PATCH 402/518] syherk streaming 3 --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 58aee287..2b5ed0c4 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -545,11 +545,10 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + syherk( L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_arrow_tip_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From 3b78056a40808e6e62cf4ae0e7b2ed077b9fca24 Mon Sep 17 00:00:00 2001 From: 03szust 
Date: Thu, 12 Jun 2025 13:47:33 +0000 Subject: [PATCH 403/518] 2 permuted streaming syherk --- src/serinv/algs/pobtaf.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 2b5ed0c4..f9ee74c8 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -768,11 +768,10 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_diagonal_blocks_d[i % 2, :, :], + syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) @@ -799,11 +798,10 @@ def _pobtaf_permuted_streaming( # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], + syherk( L_lower_arrow_blocks_d[i % 2, :, :], A_arrow_tip_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From 4a9c3ff2d68da3ff986e2b4cae8ed7d444e95bc2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Thu, 12 Jun 2025 13:48:13 +0000 Subject: [PATCH 404/518] all syherk done --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f9ee74c8..b19d486e 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -820,11 +820,10 @@ def _pobtaf_permuted_streaming( # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - gemm( - L_upper_nested_dissection_buffer_d[i % 2, :, :], + syherk( L_upper_nested_dissection_buffer_d[i % 2, :, :], 
A_diagonal_top_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From 30cfd09a935327fb1218930d32314108fa23bc46 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 12:59:52 +0000 Subject: [PATCH 405/518] test if L can be ommited from factorize last block --- src/serinv/algs/pobtaf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b19d486e..b207b2e0 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -169,10 +169,10 @@ def _pobtaf( if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) - L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) + A_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} - L_lower_arrow_blocks[-1, :, :] = ( + A_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, @@ -193,7 +193,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) + A_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) def _pobtaf_permuted( From bec3cba11bb2e11e3fd74c93119cdcc498dad9c3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 13:13:49 +0000 Subject: [PATCH 406/518] reverted the L's for now --- src/serinv/algs/pobtaf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b207b2e0..b19d486e 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -169,10 +169,10 @@ def _pobtaf( if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) - A_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) + L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) # L_{ndb+1, ndb} = A_{ndb+1, ndb} 
@ L_{ndb, ndb}^{-T} - A_lower_arrow_blocks[-1, :, :] = ( + L_lower_arrow_blocks[-1, :, :] = ( trsm( L_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, @@ -193,7 +193,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - A_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) + L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) def _pobtaf_permuted( From 1da239691172cfaa222b2e207baf6b77086c282f Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 13:18:39 +0000 Subject: [PATCH 407/518] inserted print test to check A and L after cholesky --- src/serinv/algs/pobtaf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b19d486e..5f724015 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -116,7 +116,11 @@ def _pobtaf( for i in range(0, n_diag_blocks - 1): # L_{i, i} = chol(A_{i, i}) L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) + print(L_diagonal_blocks[i, :, :]) + print("#") + print(A_diagonal_blocks[i, :, :]) + raise ValueError("TEST") # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( From cfb2cac3fdc9b917c9d4bce46a59cc230dbf9712 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 13:20:18 +0000 Subject: [PATCH 408/518] moved the print test to see ifsomething changed --- src/serinv/algs/pobtaf.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 5f724015..44ace27f 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -116,11 +116,7 @@ def _pobtaf( for i in range(0, n_diag_blocks - 1): # L_{i, i} = chol(A_{i, i}) L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) - print(L_diagonal_blocks[i, :, :]) - print("#") - print(A_diagonal_blocks[i, :, :]) - - raise ValueError("TEST") + # L_{i+1, i} = A_{i+1, 
i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( @@ -171,6 +167,12 @@ def _pobtaf( ) ) + print(L_diagonal_blocks[i, :, :]) + print("#") + print(A_diagonal_blocks[i, :, :]) + + raise ValueError("TEST") + if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) From 413aec5bb99e28d8557b39c5a502f81507393043 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 13:21:43 +0000 Subject: [PATCH 409/518] switcht break condition --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 44ace27f..5e5129a5 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -171,7 +171,7 @@ def _pobtaf( print("#") print(A_diagonal_blocks[i, :, :]) - raise ValueError("TEST") + raise ValueError("TEST") if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) From 7df1dc14ff2be26170d5a6bc9d804f00c5e27708 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 13:23:38 +0000 Subject: [PATCH 410/518] switched L diag block to A diag block to see if it works --- src/serinv/algs/pobtaf.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 5e5129a5..19ad652a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -115,12 +115,12 @@ def _pobtaf( # Forward block-Cholesky for i in range(0, n_diag_blocks - 1): # L_{i, i} = chol(A_{i, i}) - L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) + A_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( - L_diagonal_blocks[i, :, :], + A_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], trans='C',lower=True, side=1 ) @@ -131,7 +131,7 @@ def _pobtaf( # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, 
i}^{-T} L_lower_arrow_blocks[i, :, :] = ( trsm( - L_diagonal_blocks[i, :, :], + A_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, ) @@ -167,11 +167,6 @@ def _pobtaf( ) ) - print(L_diagonal_blocks[i, :, :]) - print("#") - print(A_diagonal_blocks[i, :, :]) - - raise ValueError("TEST") if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) From 97e057a45641846b4d8fdecc71c0c4a0fec6be11 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 13:25:36 +0000 Subject: [PATCH 411/518] trying to switch a few more L's --- src/serinv/algs/pobtaf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 19ad652a..425c3cc3 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -118,7 +118,7 @@ def _pobtaf( A_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} - L_lower_diagonal_blocks[i, :, :] = ( + A_lower_diagonal_blocks[i, :, :] = ( trsm( A_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], @@ -129,7 +129,7 @@ def _pobtaf( # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} - L_lower_arrow_blocks[i, :, :] = ( + A_lower_arrow_blocks[i, :, :] = ( trsm( A_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, @@ -143,7 +143,7 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( syherk( - L_lower_diagonal_blocks[i, :, :], + A_lower_diagonal_blocks[i, :, :], A_diagonal_blocks[i + 1, :, :], alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) @@ -151,8 +151,8 @@ def _pobtaf( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( - L_lower_arrow_blocks[i, :, :], - L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i, :, :], + A_lower_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i + 1, :, :], trans_b='C', alpha=-1.0, beta=1.0 ) @@ 
-161,7 +161,7 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block[:, :] = ( syherk( - L_lower_arrow_blocks[i, :, :], + A_lower_arrow_blocks[i, :, :], A_arrow_tip_block[:, :], alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) @@ -170,12 +170,12 @@ def _pobtaf( if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) - L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) + A_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} L_lower_arrow_blocks[-1, :, :] = ( trsm( - L_diagonal_blocks[-1, :, :], + A_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, lower=True, ) From 01d0caab7ca07bc23f008a027f375e39ed0c23e2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 13:26:11 +0000 Subject: [PATCH 412/518] removed all the L's --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 425c3cc3..41a64167 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -173,7 +173,7 @@ def _pobtaf( A_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} - L_lower_arrow_blocks[-1, :, :] = ( + A_lower_arrow_blocks[-1, :, :] = ( trsm( A_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, @@ -194,7 +194,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) + A_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) def _pobtaf_permuted( From 9120fff03a3633450df14791eee01b0a06150d40 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 13:26:34 +0000 Subject: [PATCH 413/518] actually removed all the L's --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 41a64167..a3ed155c 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -186,7 +186,7 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block[:, :] = ( syherk( - L_lower_arrow_blocks[-1, :, :], + A_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) From 2e0ca8b2abdd71ff482d1d789f20950d3d4f9106 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 14:01:53 +0000 Subject: [PATCH 414/518] reverted the L's back in --- src/serinv/algs/pobtaf.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index a3ed155c..b19d486e 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -115,12 +115,12 @@ def _pobtaf( # Forward block-Cholesky for i in range(0, n_diag_blocks - 1): # L_{i, i} = chol(A_{i, i}) - A_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) - + L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) + # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} - A_lower_diagonal_blocks[i, :, :] = ( + L_lower_diagonal_blocks[i, :, :] = ( trsm( - A_diagonal_blocks[i, :, :], + L_diagonal_blocks[i, :, :], A_lower_diagonal_blocks[i, :, :], trans='C',lower=True, side=1 ) @@ -129,9 +129,9 @@ def _pobtaf( # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} - A_lower_arrow_blocks[i, :, :] = ( + L_lower_arrow_blocks[i, :, :] = ( trsm( - A_diagonal_blocks[i, :, :], + L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, ) @@ -143,7 +143,7 @@ def _pobtaf( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( syherk( - A_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], A_diagonal_blocks[i + 1, :, :], alpha=-1.0, beta=1.0, lower=True, 
cu_chol=True ) @@ -151,8 +151,8 @@ def _pobtaf( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( gemm( - A_lower_arrow_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i + 1, :, :], trans_b='C', alpha=-1.0, beta=1.0 ) @@ -161,21 +161,20 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block[:, :] = ( syherk( - A_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], A_arrow_tip_block[:, :], alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) - if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) - A_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) + L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} - A_lower_arrow_blocks[-1, :, :] = ( + L_lower_arrow_blocks[-1, :, :] = ( trsm( - A_diagonal_blocks[-1, :, :], + L_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, lower=True, ) @@ -186,7 +185,7 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block[:, :] = ( syherk( - A_lower_arrow_blocks[-1, :, :], + L_lower_arrow_blocks[-1, :, :], A_arrow_tip_block[:, :], alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) @@ -194,7 +193,7 @@ def _pobtaf( # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - A_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) + L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) def _pobtaf_permuted( From cb4b4d5c5e75130a931728eb694174a6d0fce966 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 14:04:42 +0000 Subject: [PATCH 415/518] trying to put chol in place --- src/serinv/cupyfix/cholesky_lowerfill.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git 
a/src/serinv/cupyfix/cholesky_lowerfill.py b/src/serinv/cupyfix/cholesky_lowerfill.py index a5d053e0..db0d4b9f 100644 --- a/src/serinv/cupyfix/cholesky_lowerfill.py +++ b/src/serinv/cupyfix/cholesky_lowerfill.py @@ -44,8 +44,9 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: if a.size == 0: return cupy.empty(a.shape, out_dtype) - x = a.astype(dtype, order="C", copy=False) n = len(a) + a = a.astype(dtype, order="C", copy=False) + handle = device.get_cusolver_handle() dev_info = cupy.empty(1, dtype=numpy.int32) @@ -68,14 +69,14 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: potrf_bufferSize = cusolver.zpotrf_bufferSize buffersize = potrf_bufferSize( - handle, cublas.CUBLAS_FILL_MODE_LOWER, n, x.data.ptr, n + handle, cublas.CUBLAS_FILL_MODE_LOWER, n, a.data.ptr, n ) workspace = cupy.empty(buffersize, dtype=dtype) potrf( handle, lower, n, - x.data.ptr, + a.data.ptr, n, workspace.data.ptr, buffersize, @@ -85,6 +86,6 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: potrf, dev_info ) - _util._triu(x, k=0) - cupy.conjugate(x, out=x) - return x.astype(out_dtype, copy=False).T + _util._triu(a, k=0) + cupy.conjugate(a, out=a) + return a.astype(out_dtype, copy=False).T From 940407a48c8d45711e052f09b998490b8cdf28be Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 14:19:32 +0000 Subject: [PATCH 416/518] reverted cholesky_lowerfill because it changed nothing --- src/serinv/cupyfix/cholesky_lowerfill.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/serinv/cupyfix/cholesky_lowerfill.py b/src/serinv/cupyfix/cholesky_lowerfill.py index db0d4b9f..a5d053e0 100644 --- a/src/serinv/cupyfix/cholesky_lowerfill.py +++ b/src/serinv/cupyfix/cholesky_lowerfill.py @@ -44,9 +44,8 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: if a.size == 0: return cupy.empty(a.shape, out_dtype) + x = a.astype(dtype, order="C", copy=False) n = len(a) - a = a.astype(dtype, order="C", copy=False) - handle = 
device.get_cusolver_handle() dev_info = cupy.empty(1, dtype=numpy.int32) @@ -69,14 +68,14 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: potrf_bufferSize = cusolver.zpotrf_bufferSize buffersize = potrf_bufferSize( - handle, cublas.CUBLAS_FILL_MODE_LOWER, n, a.data.ptr, n + handle, cublas.CUBLAS_FILL_MODE_LOWER, n, x.data.ptr, n ) workspace = cupy.empty(buffersize, dtype=dtype) potrf( handle, lower, n, - a.data.ptr, + x.data.ptr, n, workspace.data.ptr, buffersize, @@ -86,6 +85,6 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: potrf, dev_info ) - _util._triu(a, k=0) - cupy.conjugate(a, out=a) - return a.astype(out_dtype, copy=False).T + _util._triu(x, k=0) + cupy.conjugate(x, out=x) + return x.astype(out_dtype, copy=False).T From 54585917baa21ce24879f0cdeadcb6b234e81129 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 14:54:44 +0000 Subject: [PATCH 417/518] switched override b to true in trsm --- src/serinv/block_primitive/trsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py index 0185bed9..0c236350 100644 --- a/src/serinv/block_primitive/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -15,7 +15,7 @@ pass def trsm(a, b, trans=0, lower = False, unit_diagonal=False, - overwrite_b=False, check_finite=False, side=0): + overwrite_b=True, check_finite=False, side=0): """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept From 144cf8a08c47bab5fc2a74ab374751b5235032ed Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 15:04:54 +0000 Subject: [PATCH 418/518] added nvtx for testing --- src/serinv/cupyfix/cholesky_lowerfill.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serinv/cupyfix/cholesky_lowerfill.py b/src/serinv/cupyfix/cholesky_lowerfill.py 
index a5d053e0..b09b4160 100644 --- a/src/serinv/cupyfix/cholesky_lowerfill.py +++ b/src/serinv/cupyfix/cholesky_lowerfill.py @@ -6,6 +6,8 @@ from cupy.cuda import device from cupy.linalg import _util +from cupy.cuda.nvtx import RangePush, RangePop + def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: """Cholesky decomposition. @@ -34,6 +36,7 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: .. seealso:: :func:`numpy.linalg.cholesky` """ + RangePush("cholesky") from cupy_backends.cuda.libs import cublas, cusolver _util._assert_cupy_array(a) @@ -87,4 +90,5 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: _util._triu(x, k=0) cupy.conjugate(x, out=x) + RangePop() return x.astype(out_dtype, copy=False).T From 0315503ab197ae702e4707ff354c10e1bd9221f1 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 16:02:55 +0000 Subject: [PATCH 419/518] more nvtx --- src/serinv/algs/pobtaf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b19d486e..b2a219b3 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -9,6 +9,8 @@ from serinv.block_primitive import trsm, gemm, syherk +from cupy.cuda.nvtx import RangePush, RangePop + def pobtaf( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, @@ -115,8 +117,9 @@ def _pobtaf( # Forward block-Cholesky for i in range(0, n_diag_blocks - 1): # L_{i, i} = chol(A_{i, i}) + RangePush("chol call") L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) - + RangePop() # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( From 5743a68e0039e61d7ef55f9c767a172047a3bf19 Mon Sep 17 00:00:00 2001 From: 03szust Date: Mon, 16 Jun 2025 17:02:49 +0000 Subject: [PATCH 420/518] removed nvtx --- src/serinv/algs/pobtaf.py | 4 ---- src/serinv/cupyfix/cholesky_lowerfill.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/src/serinv/algs/pobtaf.py 
b/src/serinv/algs/pobtaf.py index b2a219b3..e47880c3 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -9,8 +9,6 @@ from serinv.block_primitive import trsm, gemm, syherk -from cupy.cuda.nvtx import RangePush, RangePop - def pobtaf( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, @@ -117,9 +115,7 @@ def _pobtaf( # Forward block-Cholesky for i in range(0, n_diag_blocks - 1): # L_{i, i} = chol(A_{i, i}) - RangePush("chol call") L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) - RangePop() # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( trsm( diff --git a/src/serinv/cupyfix/cholesky_lowerfill.py b/src/serinv/cupyfix/cholesky_lowerfill.py index b09b4160..a5d053e0 100644 --- a/src/serinv/cupyfix/cholesky_lowerfill.py +++ b/src/serinv/cupyfix/cholesky_lowerfill.py @@ -6,8 +6,6 @@ from cupy.cuda import device from cupy.linalg import _util -from cupy.cuda.nvtx import RangePush, RangePop - def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: """Cholesky decomposition. @@ -36,7 +34,6 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: .. 
seealso:: :func:`numpy.linalg.cholesky` """ - RangePush("cholesky") from cupy_backends.cuda.libs import cublas, cusolver _util._assert_cupy_array(a) @@ -90,5 +87,4 @@ def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: _util._triu(x, k=0) cupy.conjugate(x, out=x) - RangePop() return x.astype(out_dtype, copy=False).T From a2d821c447f2b1d8d243a1269eddfcdbcbe3d8e9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 07:54:29 +0000 Subject: [PATCH 421/518] added streaming tests to pobtaf --- tests/tests_algs/regular/tests_bta/test_pobtaf.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index a30b9094..877933d6 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -3,11 +3,20 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize from serinv.algs import pobtaf +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + if backend_flags["cupy_avail"]: import cupyx as cpx From 9d300b96fb41e5bf0b3e4240aade7de0d3358a8a Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 07:59:12 +0000 Subject: [PATCH 422/518] check if streaming is happening --- tests/tests_algs/regular/tests_bta/test_pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index 877933d6..eb4fdd7b 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -66,6 +66,7 @@ def test_pobtaf( A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned A_lower_arrow_blocks = A_lower_arrow_blocks_pinned A_arrow_tip_block = 
A_arrow_tip_block_pinned + raise ValueError("Streaming") pobtaf( A_diagonal_blocks, From d722ad866cbd13ab1f9a31287eb0ed202e01d03b Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:01:37 +0000 Subject: [PATCH 423/518] added more testing code --- tests/tests_algs/regular/tests_bta/test_pobtaf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index eb4fdd7b..981ff670 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -20,6 +20,9 @@ if backend_flags["cupy_avail"]: import cupyx as cpx +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param @pytest.mark.mpi_skip() def test_pobtaf( @@ -66,7 +69,6 @@ def test_pobtaf( A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned A_lower_arrow_blocks = A_lower_arrow_blocks_pinned A_arrow_tip_block = A_arrow_tip_block_pinned - raise ValueError("Streaming") pobtaf( A_diagonal_blocks, From 4d7043d4d0e7920fbc1bbfc9976e7a77a7edba25 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:10:42 +0000 Subject: [PATCH 424/518] permuted pobtasi streaming --- .../permuted/test_bta/test_pobtasi_permuted.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py index dc6e3a25..e4b9fc40 100644 --- a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py +++ b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py @@ -3,15 +3,28 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize from serinv.utils import allocate_pobtax_permutation_buffers from serinv.algs 
import pobtaf, pobtasi +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + if backend_flags["cupy_avail"]: import cupyx as cpx +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtasi_permuted( From feaa90ec5b2a90d553230634d68d906c608d49db Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:15:52 +0000 Subject: [PATCH 425/518] removed improvements from pobtaf perm --- src/serinv/algs/pobtaf.py | 45 +++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index e47880c3..0db633e7 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -421,11 +421,13 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -446,7 +448,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -474,31 +476,24 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( - L_lower_diagonal_blocks_d[i % 2, :, :], - A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_diagonal_blocks_d[(i + 1) % 2, :, :] + - 
L_lower_diagonal_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], - L_lower_diagonal_blocks_d[i % 2, :, :], - A_lower_arrow_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks_d[(i + 1) % 2, :, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - syherk( - L_lower_arrow_blocks_d[i % 2, :, :], - A_arrow_tip_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_arrow_tip_block_d[:, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) @@ -523,7 +518,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -544,11 +539,9 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - syherk( - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], - A_arrow_tip_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_arrow_tip_block_d[:, :] + - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] + @ L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 6da2b8ab1a2e7f90985933baa1b66b1626835981 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:20:00 
+0000 Subject: [PATCH 426/518] check first imp in perm pobtaf --- src/serinv/algs/pobtaf.py | 53 +++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0db633e7..941ec8ba 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -421,13 +421,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -448,7 +446,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -476,24 +474,31 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) 
) compute_lower_h2d_events[i % 2].record(stream=compute_stream) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) @@ -518,7 +523,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -539,9 +544,11 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] - @ L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T + syherk( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + A_arrow_tip_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) @@ -684,11 +691,13 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) cp_lower_events[i % 2].record(stream=compute_stream) From b7a541c64ae0fead9dddbffa89d2724f1cbb47b8 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:20:37 +0000 Subject: [PATCH 427/518] second imp --- src/serinv/algs/pobtaf.py 
| 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 941ec8ba..170aff5a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -691,13 +691,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -718,7 +716,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, From f1f99dba16bfbc73c3781704d6fc5ac6855c08cf Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:21:09 +0000 Subject: [PATCH 428/518] third imp test --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 170aff5a..0954a2bb 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -716,7 +716,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -736,7 +736,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, From e229ec1dc5e44981bfd1eff985abaff1fa88dafa Mon Sep 
17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:21:53 +0000 Subject: [PATCH 429/518] 4 imp test --- src/serinv/algs/pobtaf.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0954a2bb..0e544b67 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -736,7 +736,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, @@ -767,11 +767,9 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( - L_lower_diagonal_blocks_d[i % 2, :, :], - A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_diagonal_blocks_d[(i + 1) % 2, :, :] + - L_lower_diagonal_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From 963eb9f4e1cb9092929243d7768f161b67613af4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:22:21 +0000 Subject: [PATCH 430/518] 5 imp test --- src/serinv/algs/pobtaf.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 0e544b67..02f1babc 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -767,19 +767,18 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + 
syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], - L_lower_diagonal_blocks_d[i % 2, :, :], - A_lower_arrow_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks_d[(i + 1) % 2, :, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T From 66547e177831b0130e923461c99a87d45f5ce807 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:22:57 +0000 Subject: [PATCH 431/518] 6 imp test --- src/serinv/algs/pobtaf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 02f1babc..40f599be 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -776,18 +776,18 @@ def _pobtaf_permuted_streaming( # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - gemm( - L_upper_nested_dissection_buffer_d[i % 2, :, :], - L_lower_diagonal_blocks_d[i % 2, :, :], - trans_b='C', alpha=-1.0 - ) + -L_upper_nested_dissection_buffer_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) From f6ac6835d4e049e9937c37cad334e7a81dbe945e Mon Sep 17 00:00:00 2001 From: 
03szust Date: Tue, 17 Jun 2025 08:23:33 +0000 Subject: [PATCH 432/518] 7 imp test --- src/serinv/algs/pobtaf.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 40f599be..e5f07441 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -786,19 +786,20 @@ def _pobtaf_permuted_streaming( # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - -L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + trans_b='C', alpha=-1.0 + ) ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - syherk( - L_lower_arrow_blocks_d[i % 2, :, :], - A_arrow_tip_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_arrow_tip_block_d[:, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T ) # Update the top (first blocks) of the arrowhead From 1dbc0d9c33381158115daf716da2042e7c2315a5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:24:07 +0000 Subject: [PATCH 433/518] 8 imp test --- src/serinv/algs/pobtaf.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index e5f07441..5af1515e 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -797,20 +797,19 @@ def _pobtaf_permuted_streaming( # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + syherk( + 
L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_arrow_bottom_top_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], - L_upper_nested_dissection_buffer_d[i % 2, :, :], - A_arrow_bottom_top_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_arrow_bottom_top_block_d[:, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T ) cp_arrow_events_h2d_release[i % 2].record(stream=compute_stream) From 537488b5edfb388fa74a19d40327b6da82c92f31 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:24:38 +0000 Subject: [PATCH 434/518] 9 imp test --- src/serinv/algs/pobtaf.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 5af1515e..af5b8f39 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -807,20 +807,21 @@ def _pobtaf_permuted_streaming( # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_arrow_bottom_top_block_d[:, :] = ( - A_arrow_bottom_top_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_arrow_bottom_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) cp_arrow_events_h2d_release[i % 2].record(stream=compute_stream) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - syherk( - L_upper_nested_dissection_buffer_d[i % 2, :, :], - A_diagonal_top_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + 
A_diagonal_top_block_d[:, :] + - L_upper_nested_dissection_buffer_d[i % 2, :, :] + @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T ) # --- Device 2 Host transfers --- From e1e5583e1deebe8a9bd29fc9ebe163e25125ab1b Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:25:37 +0000 Subject: [PATCH 435/518] check full --- src/serinv/algs/pobtaf.py | 53 ++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 32 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index af5b8f39..ff5087f2 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -691,11 +691,13 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -716,7 +718,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -736,7 +738,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, @@ -767,52 +769,39 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( - L_lower_diagonal_blocks_d[i % 2, :, :], - A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, 
cu_chol=True - ) + A_diagonal_blocks_d[(i + 1) % 2, :, :] + - L_lower_diagonal_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], - L_lower_diagonal_blocks_d[i % 2, :, :], - A_lower_arrow_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks_d[(i + 1) % 2, :, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - gemm( - L_upper_nested_dissection_buffer_d[i % 2, :, :], - L_lower_diagonal_blocks_d[i % 2, :, :], - trans_b='C', alpha=-1.0 - ) + -L_upper_nested_dissection_buffer_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - syherk( - L_lower_arrow_blocks_d[i % 2, :, :], - A_arrow_tip_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_arrow_tip_block_d[:, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_arrow_bottom_top_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], - L_upper_nested_dissection_buffer_d[i % 2, :, :], - A_arrow_bottom_top_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_arrow_bottom_top_block_d[:, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T ) cp_arrow_events_h2d_release[i % 2].record(stream=compute_stream) From 8efe2dab812d2242d062db109ca995ee3233097d Mon Sep 17 00:00:00 2001 From: 03szust Date: 
Tue, 17 Jun 2025 08:28:06 +0000 Subject: [PATCH 436/518] removed all improvemnts from pobtaf permuted for sanity, check now sole improvemnts to find error --- src/serinv/algs/pobtaf.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index ff5087f2..a8faa8f2 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -691,13 +691,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) From dea6e7009b6d4f53331cd82f72208713dc95940d Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:28:46 +0000 Subject: [PATCH 437/518] 2 imp add --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index a8faa8f2..e4bbe951 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -716,7 +716,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, From 00f2e3108977b52cb131738602997d70c025331f Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:29:24 +0000 Subject: [PATCH 438/518] 3 imp add --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index e4bbe951..e8bd8910 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -736,7 +736,7 @@ def 
_pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, From 4bce6bcc41df30fb4e9da01789e1bc39eebc9d67 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:30:10 +0000 Subject: [PATCH 439/518] 4 imp add --- src/serinv/algs/pobtaf.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index e8bd8910..6b3cf594 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -767,9 +767,11 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From d414da3c776c0bd09be477b6e7a0dfdc19832ef4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:31:13 +0000 Subject: [PATCH 440/518] fix attempt first error --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 6b3cf594..12cc25c1 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -770,7 +770,7 @@ def _pobtaf_permuted_streaming( syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True + alpha=-1.0, beta=1.0, lower=True, cu_chol=False ) ) From 3a1ad6e6a3c46378bf9428525608fe0957630285 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 
17 Jun 2025 08:36:09 +0000 Subject: [PATCH 441/518] print added to check for error --- src/serinv/algs/pobtaf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 12cc25c1..f3e83af2 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -770,9 +770,10 @@ def _pobtaf_permuted_streaming( syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=False + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) + print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( From e0d7bf7f4d24a1f269003204ceba67fc08ce4426 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:38:46 +0000 Subject: [PATCH 442/518] added test for error --- src/serinv/algs/pobtaf.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f3e83af2..ca34a454 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -767,13 +767,12 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( - L_lower_diagonal_blocks_d[i % 2, :, :], - A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_diagonal_blocks_d[(i + 1) % 2, :, :] + - L_lower_diagonal_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) + raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( From 837135221775037cd54277b01a7dab8e7a84c44e Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:44:02 +0000 Subject: [PATCH 
443/518] switched lower --- src/serinv/algs/pobtaf.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index ca34a454..6cc5a9ca 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -767,24 +767,32 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=False, cu_chol=True + ) ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) - raise ValueError("TEST") + #raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - -L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + trans_b='C', alpha=-1.0 + ) ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) From e54697d6e1cb7c257f450716bcea6c06154c2502 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:46:27 +0000 Subject: [PATCH 444/518] added print --- src/serinv/algs/pobtaf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 6cc5a9ca..be8cf4a6 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -766,11 +766,12 @@ def _pobtaf_permuted_streaming( # Update next diagonal block compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T + print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=False, cu_chol=True + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) From 3738bf7bd79742273e948cf95ab1c1a3edb53133 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:47:41 +0000 Subject: [PATCH 445/518] reverted to nonimproved to see actual sol --- src/serinv/algs/pobtaf.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index be8cf4a6..1499968f 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -768,11 +768,9 @@ def _pobtaf_permuted_streaming( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( - L_lower_diagonal_blocks_d[i % 2, :, :], - A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_diagonal_blocks_d[(i + 1) % 2, :, :] + - L_lower_diagonal_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) #raise ValueError("TEST") From b3a7e085c11617c02a92ff4e244648f1713bfda3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 08:48:00 +0000 Subject: [PATCH 446/518] added break --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py 
b/src/serinv/algs/pobtaf.py index 1499968f..cf45ba38 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -773,7 +773,7 @@ def _pobtaf_permuted_streaming( @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) - #raise ValueError("TEST") + raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( From 78173383f5aaba9fe109dd5be53988f2e8244266 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:00:58 +0000 Subject: [PATCH 447/518] added lower to chol --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index cf45ba38..a32d34af 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -668,7 +668,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) L_diagonal_blocks_d[i % 2, :, :] = cholesky( - A_diagonal_blocks_d[i % 2, :, :] + A_diagonal_blocks_d[i % 2, :, :], lower=True ) cp_diagonal_events[i % 2].record(stream=compute_stream) @@ -773,7 +773,7 @@ def _pobtaf_permuted_streaming( @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) - raise ValueError("TEST") + #raise ValueError("TEST") # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( From 95e34aabeb9dd10238aae91d1e433ef92a4aba6d Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:01:48 +0000 Subject: [PATCH 448/518] trying syherk again --- src/serinv/algs/pobtaf.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index a32d34af..1e3a7193 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -768,9 +768,11 @@ def _pobtaf_permuted_streaming( # A_{i+1, i+1} = A_{i+1, 
i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) #raise ValueError("TEST") From 709335e5ef9eecab1e335c7ba0ea92f2a8fc43c2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:04:09 +0000 Subject: [PATCH 449/518] print chol sol as well --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 1e3a7193..20adc62e 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -671,6 +671,7 @@ def _pobtaf_permuted_streaming( A_diagonal_blocks_d[i % 2, :, :], lower=True ) cp_diagonal_events[i % 2].record(stream=compute_stream) + print(L_diagonal_blocks_d[i % 2, :, :]) d2h_stream.wait_event(cp_diagonal_events[i % 2]) L_diagonal_blocks_d[i % 2, :, :].get( From a568fa2f5112184735b01b8b49c22ef40942deb7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:04:48 +0000 Subject: [PATCH 450/518] removed syherk --- src/serinv/algs/pobtaf.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 20adc62e..f8359e18 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -769,11 +769,9 @@ def _pobtaf_permuted_streaming( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( - L_lower_diagonal_blocks_d[i % 2, :, :], - A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True - ) + A_diagonal_blocks_d[(i + 1) % 2, :, :] + - 
L_lower_diagonal_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) #raise ValueError("TEST") From daccf3f892467113496027ce2028145c75475b17 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:05:46 +0000 Subject: [PATCH 451/518] moved error --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f8359e18..b493c338 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -774,7 +774,7 @@ def _pobtaf_permuted_streaming( @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) - #raise ValueError("TEST") + # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( @@ -820,7 +820,7 @@ def _pobtaf_permuted_streaming( - L_upper_nested_dissection_buffer_d[i % 2, :, :] @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T ) - + raise ValueError("TEST") # --- Device 2 Host transfers --- d2h_stream.wait_event(cp_lower_events_h2d_release[(n_diag_blocks - 2) % 2]) A_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :].get( From befa65b5ef1dc5ddf535e6eb785b06cd93faba6d Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:12:48 +0000 Subject: [PATCH 452/518] added syherk to test for test behaviour --- src/serinv/algs/pobtaf.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b493c338..f6c64a30 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -769,9 +769,11 @@ def _pobtaf_permuted_streaming( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 
2, :, :].conj().T + syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) @@ -820,7 +822,7 @@ def _pobtaf_permuted_streaming( - L_upper_nested_dissection_buffer_d[i % 2, :, :] @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T ) - raise ValueError("TEST") + # --- Device 2 Host transfers --- d2h_stream.wait_event(cp_lower_events_h2d_release[(n_diag_blocks - 2) % 2]) A_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :].get( From 252abef1eb06398eff07d8f9a7685fc716dc9ecc Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:14:10 +0000 Subject: [PATCH 453/518] added prints in test --- tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py index e4b9fc40..00550a17 100644 --- a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py +++ b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py @@ -137,6 +137,8 @@ def test_pobtasi_permuted( _A_arrow_tip_block, ) + print(X_arrow_tip_block_ref) + print(_A_arrow_tip_block) # Verify that the reduced system is already correct assert xp.allclose(X_arrow_tip_block_ref, _A_arrow_tip_block) assert xp.allclose(X_diagonal_blocks_ref[0], _A_diagonal_blocks[0]) From fc37075dd075847f3971f0ca4b75c9f404ae57dc Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:16:32 +0000 Subject: [PATCH 454/518] replaced syherk with gemm in permuted streaming because of a propagation error --- src/serinv/algs/pobtaf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f6c64a30..91dc7c68 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -769,10 +769,11 @@ def _pobtaf_permuted_streaming( # A_{i+1, i+1} 
= A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True + trans_b='C', alpha=-1.0, beta=1.0 ) ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) From e16d8038f1caadfde8f3fa11c1b9a1718af3d9c5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:18:22 +0000 Subject: [PATCH 455/518] added all syher withc cu_chol false --- src/serinv/algs/pobtaf.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 91dc7c68..9f619cd4 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -769,11 +769,10 @@ def _pobtaf_permuted_streaming( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_diagonal_blocks_d[i % 2, :, :], + syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=False ) ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) @@ -802,26 +801,33 @@ def _pobtaf_permuted_streaming( # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=False + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_arrow_bottom_top_block_d[:, :] = ( - 
A_arrow_bottom_top_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_arrow_bottom_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) cp_arrow_events_h2d_release[i % 2].record(stream=compute_stream) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - A_diagonal_top_block_d[:, :] - - L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + syherk( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_diagonal_top_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=False + ) ) # --- Device 2 Host transfers --- From f901ee22a419d66baa2e0f941b663657862a5dc6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:23:31 +0000 Subject: [PATCH 456/518] check if diagonal blocks are similar --- tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py index 00550a17..e811d360 100644 --- a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py +++ b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py @@ -137,8 +137,8 @@ def test_pobtasi_permuted( _A_arrow_tip_block, ) - print(X_arrow_tip_block_ref) - print(_A_arrow_tip_block) + print(X_diagonal_blocks_ref) + print(_A_diagonal_blocks) # Verify that the reduced system is already correct assert xp.allclose(X_arrow_tip_block_ref, _A_arrow_tip_block) assert xp.allclose(X_diagonal_blocks_ref[0], _A_diagonal_blocks[0]) From f0eacf06b558ece284a9d45d156d0ac14cb45ed9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:25:38 +0000 Subject: [PATCH 457/518] 
removed intermittend assert --- .../permuted/test_bta/test_pobtasi_permuted.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py index e811d360..8cabbae3 100644 --- a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py +++ b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py @@ -140,11 +140,11 @@ def test_pobtasi_permuted( print(X_diagonal_blocks_ref) print(_A_diagonal_blocks) # Verify that the reduced system is already correct - assert xp.allclose(X_arrow_tip_block_ref, _A_arrow_tip_block) - assert xp.allclose(X_diagonal_blocks_ref[0], _A_diagonal_blocks[0]) - assert xp.allclose(X_diagonal_blocks_ref[-1], _A_diagonal_blocks[-1]) - assert xp.allclose(X_lower_arrow_blocks_ref[0], _A_lower_arrow_blocks[0]) - assert xp.allclose(X_lower_arrow_blocks_ref[-1], _A_lower_arrow_blocks[-1]) + #assert xp.allclose(X_arrow_tip_block_ref, _A_arrow_tip_block) + #assert xp.allclose(X_diagonal_blocks_ref[0], _A_diagonal_blocks[0]) + #assert xp.allclose(X_diagonal_blocks_ref[-1], _A_diagonal_blocks[-1]) + #assert xp.allclose(X_lower_arrow_blocks_ref[0], _A_lower_arrow_blocks[0]) + #assert xp.allclose(X_lower_arrow_blocks_ref[-1], _A_lower_arrow_blocks[-1]) # Map back the correct reduced system to the original system A_diagonal_blocks[0] = _A_diagonal_blocks[0] From f8a09c1404d78255b3072bad8557fdfa85b2b8b9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:26:43 +0000 Subject: [PATCH 458/518] changed one syherk to gemm --- src/serinv/algs/pobtaf.py | 5 +++-- .../permuted/test_bta/test_pobtasi_permuted.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 9f619cd4..bcc17627 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -769,10 +769,11 @@ def _pobtaf_permuted_streaming( # A_{i+1, i+1} = A_{i+1, 
i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=False + trans_b='C', alpha=-1.0, beta=1.0 ) ) print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) diff --git a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py index 8cabbae3..e811d360 100644 --- a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py +++ b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py @@ -140,11 +140,11 @@ def test_pobtasi_permuted( print(X_diagonal_blocks_ref) print(_A_diagonal_blocks) # Verify that the reduced system is already correct - #assert xp.allclose(X_arrow_tip_block_ref, _A_arrow_tip_block) - #assert xp.allclose(X_diagonal_blocks_ref[0], _A_diagonal_blocks[0]) - #assert xp.allclose(X_diagonal_blocks_ref[-1], _A_diagonal_blocks[-1]) - #assert xp.allclose(X_lower_arrow_blocks_ref[0], _A_lower_arrow_blocks[0]) - #assert xp.allclose(X_lower_arrow_blocks_ref[-1], _A_lower_arrow_blocks[-1]) + assert xp.allclose(X_arrow_tip_block_ref, _A_arrow_tip_block) + assert xp.allclose(X_diagonal_blocks_ref[0], _A_diagonal_blocks[0]) + assert xp.allclose(X_diagonal_blocks_ref[-1], _A_diagonal_blocks[-1]) + assert xp.allclose(X_lower_arrow_blocks_ref[0], _A_lower_arrow_blocks[0]) + assert xp.allclose(X_lower_arrow_blocks_ref[-1], _A_lower_arrow_blocks[-1]) # Map back the correct reduced system to the original system A_diagonal_blocks[0] = _A_diagonal_blocks[0] From c5ae3ef62581dd0520432da2c38fc558591646b4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:30:28 +0000 Subject: [PATCH 459/518] added comment explaining weird behaviour --- src/serinv/algs/pobtaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index bcc17627..674ff796 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -767,8 +767,9 @@ def _pobtaf_permuted_streaming( # Update next diagonal block compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T - print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( + # gemm instead of syherk because this somehow kept failing tests in a very weird way + # probably because both sides of the diagonal matrix are used somwhere in a relevant way gemm( L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], @@ -776,7 +777,6 @@ def _pobtaf_permuted_streaming( trans_b='C', alpha=-1.0, beta=1.0 ) ) - print(A_diagonal_blocks_d[(i + 1) % 2, :, :]) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T From 0ca95d6408e2f8fe4b23f955631c1dbf008df486 Mon Sep 17 00:00:00 2001 From: 03szust Date: Tue, 17 Jun 2025 09:42:30 +0000 Subject: [PATCH 460/518] fix for syherk issue in normal streaming --- src/serinv/algs/pobtaf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 674ff796..f9db298a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -474,10 +474,11 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True + trans_b='C', alpha=-1.0, beta=1.0 ) ) From 83d12f96994de8c1c3b51e60d968a54eadf6d2b9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:07:33 +0000 Subject: [PATCH 461/518] added on pobtas trsm and trying to fix pobtaf 
streaming --- src/serinv/algs/pobtaf.py | 1 - src/serinv/algs/pobtas.py | 12 ++++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f9db298a..4a1eed07 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -475,7 +475,6 @@ def _pobtaf_streaming( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( gemm( - L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], trans_b='C', alpha=-1.0, beta=1.0 diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 953906ed..2dbc1fab 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -7,6 +7,7 @@ _get_module_from_str, ) +from serinv.block_primitive import trsm, gemm, syherk def pobtas( L_diagonal_blocks: ArrayLike, @@ -87,11 +88,14 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True + ) ) + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( L_lower_diagonal_blocks[i] From b88fb758f1046aec45c8955b309c0d701783b0ce Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:08:32 +0000 Subject: [PATCH 462/518] fixing the fix --- src/serinv/algs/pobtaf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 4a1eed07..f9db298a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -475,6 +475,7 @@ def _pobtaf_streaming( # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( gemm( + 
L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], trans_b='C', alpha=-1.0, beta=1.0 From 2ecb420afcd9b3a7145921fe33865f78a9939266 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:10:27 +0000 Subject: [PATCH 463/518] further syherk fix --- src/serinv/algs/pobtaf.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index f9db298a..09543f7c 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -474,11 +474,10 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_diagonal_blocks_d[i % 2, :, :], + syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=False ) ) @@ -498,7 +497,7 @@ def _pobtaf_streaming( syherk( L_lower_arrow_blocks_d[i % 2, :, :], A_arrow_tip_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True + alpha=-1.0, beta=1.0, lower=True, cu_chol=False ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) @@ -548,7 +547,7 @@ def _pobtaf_streaming( syherk( L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_arrow_tip_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True + alpha=-1.0, beta=1.0, lower=True, cu_chol=False ) ) From c5a19937328396965663ebd35bb434789b840a7b Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:11:14 +0000 Subject: [PATCH 464/518] more fixing --- src/serinv/algs/pobtaf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 09543f7c..17c62328 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -474,7 +474,8 @@ def _pobtaf_streaming( 
compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - syherk( + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], alpha=-1.0, beta=1.0, lower=True, cu_chol=False From 7982cf88501c43d007103530fd5ee4396361541b Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:11:46 +0000 Subject: [PATCH 465/518] fixed keywords --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 17c62328..7b0e7671 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -478,7 +478,7 @@ def _pobtaf_streaming( L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=False + trans_b='C', alpha=-1.0, beta=1.0 ) ) From e349e21fe4b370f5cce3711111f563e9f877f070 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:12:24 +0000 Subject: [PATCH 466/518] next syherk fix --- src/serinv/algs/pobtaf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 7b0e7671..c910f26a 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -495,10 +495,11 @@ def _pobtaf_streaming( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - syherk( + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], L_lower_arrow_blocks_d[i % 2, :, :], A_arrow_tip_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=False + trans_b='C', alpha=-1.0, beta=1.0 ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) From c175b8e0890008cb6797b9b9e8c42dcd87ebd526 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:13:10 +0000 Subject: [PATCH 
467/518] third syherk fix --- src/serinv/algs/pobtaf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index c910f26a..2a80e5d8 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -546,10 +546,11 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - syherk( + gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_arrow_tip_block_d[:, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=False + trans_b='C', alpha=-1.0, beta=1.0 ) ) From 33cbb9a8e7d751be78e4ee13f757c3a0ad7551b5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:15:54 +0000 Subject: [PATCH 468/518] replaced with working streaming code --- src/serinv/algs/pobtaf.py | 48 ++++++++++++++++----------------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 2a80e5d8..b0876a64 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -421,11 +421,13 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :], - trans='C',lower=True, side=1 + A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, + lower=True, ) + .conj() + .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -446,7 +448,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -474,33 +476,24 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 
2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_diagonal_blocks_d[i % 2, :, :], - L_lower_diagonal_blocks_d[i % 2, :, :], - A_diagonal_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_diagonal_blocks_d[(i + 1) % 2, :, :] + - L_lower_diagonal_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], - L_lower_diagonal_blocks_d[i % 2, :, :], - A_lower_arrow_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_lower_arrow_blocks_d[(i + 1) % 2, :, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], - L_lower_arrow_blocks_d[i % 2, :, :], - A_arrow_tip_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_arrow_tip_block_d[:, :] + - L_lower_arrow_blocks_d[i % 2, :, :] + @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) @@ -525,7 +518,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - trsm( + cu_la.solve_triangular( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -546,12 +539,9 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], - L_lower_arrow_blocks_d[(n_diag_blocks - 1) 
% 2, :, :], - A_arrow_tip_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 - ) + A_arrow_tip_block_d[:, :] + - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] + @ L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From 118745e303a7f7fa98191a75bdd2af77112eb1f9 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:22:02 +0000 Subject: [PATCH 469/518] reset streaming code to original --- src/serinv/algs/pobtas.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 2dbc1fab..953906ed 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -7,7 +7,6 @@ _get_module_from_str, ) -from serinv.block_primitive import trsm, gemm, syherk def pobtas( L_diagonal_blocks: ArrayLike, @@ -88,14 +87,11 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( - trsm( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True - ) + B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( L_lower_diagonal_blocks[i] From 591aca624558b7784276929735437d1c5a016fd5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:48:27 +0000 Subject: [PATCH 470/518] added streaming improvements back in --- src/serinv/algs/pobtaf.py | 48 +++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index b0876a64..2a80e5d8 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -421,13 +421,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) 
L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -448,7 +446,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -476,24 +474,33 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + 
) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) @@ -518,7 +525,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -539,9 +546,12 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] - @ L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) From d1b357056f649de96344a8db2a58fd79b9360106 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:49:46 +0000 Subject: [PATCH 471/518] reverted to start of day to see if issue is local --- src/serinv/algs/pobtaf.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 2a80e5d8..674ff796 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -474,11 +474,10 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - gemm( - L_lower_diagonal_blocks_d[i % 2, :, :], + syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) @@ -495,11 +494,10 @@ def _pobtaf_streaming( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, 
i}.conj().T A_arrow_tip_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[i % 2, :, :], + syherk( L_lower_arrow_blocks_d[i % 2, :, :], A_arrow_tip_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) @@ -546,11 +544,10 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - gemm( - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + syherk( L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_arrow_tip_block_d[:, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From a76e378c000586d6ddbe327bd6a859e8d5ab3639 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:51:07 +0000 Subject: [PATCH 472/518] check if issue with permuted streaming was local --- src/serinv/algs/pobtaf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 674ff796..6d1aab7d 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -770,11 +770,10 @@ def _pobtaf_permuted_streaming( A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( # gemm instead of syherk because this somehow kept failing tests in a very weird way # probably because both sides of the diagonal matrix are used somwhere in a relevant way - gemm( - L_lower_diagonal_blocks_d[i % 2, :, :], + syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - trans_b='C', alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0, lower=True, cu_chol=False ) ) From 39b15c94dda5ed5aaebf1d0c55aa9a0db415d32a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:51:39 +0000 Subject: [PATCH 473/518] further test --- src/serinv/algs/pobtaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py 
index 6d1aab7d..24a68e39 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -773,7 +773,7 @@ def _pobtaf_permuted_streaming( syherk( L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=False + alpha=-1.0, beta=1.0, lower=True, cu_chol=True ) ) From eac28229c34cf394074411ba9c621e7df89e1768 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 08:52:32 +0000 Subject: [PATCH 474/518] issue wasn't local, reverted to gemm --- src/serinv/algs/pobtaf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 24a68e39..674ff796 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -770,10 +770,11 @@ def _pobtaf_permuted_streaming( A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( # gemm instead of syherk because this somehow kept failing tests in a very weird way # probably because both sides of the diagonal matrix are used somwhere in a relevant way - syherk( + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], L_lower_diagonal_blocks_d[i % 2, :, :], A_diagonal_blocks_d[(i + 1) % 2, :, :], - alpha=-1.0, beta=1.0, lower=True, cu_chol=True + trans_b='C', alpha=-1.0, beta=1.0 ) ) From 1c88f06c66aebf63fcdb289ccf42dc6f549bbc97 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:02:08 +0000 Subject: [PATCH 475/518] first two improvemnts in pobtas --- src/serinv/algs/pobtas.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 953906ed..1c03935f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -7,6 +7,7 @@ _get_module_from_str, ) +from serinv.block_primitive import trsm, gemm, syherk def pobtas( L_diagonal_blocks: ArrayLike, @@ -87,15 +88,21 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * 
diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True + ) ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] + alpha=-1.0, beta=1.0 + ) ) B[-arrow_blocksize:] -= ( From adef207e4f6565bf75f88f56c99f0323de010fc4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:03:50 +0000 Subject: [PATCH 476/518] fixed missing comma --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 1c03935f..9970c38b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -100,7 +100,7 @@ def _pobtas( gemm( L_lower_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], alpha=-1.0, beta=1.0 ) ) From 39c898fcea74d9dd4ac6b04ccc9e84b2e2090235 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:04:40 +0000 Subject: [PATCH 477/518] removed gemm improvemnt --- src/serinv/algs/pobtas.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 9970c38b..936f78ca 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -97,12 +97,8 @@ def _pobtas( ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( - gemm( - L_lower_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], 
- B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], - alpha=-1.0, beta=1.0 - ) + L_lower_diagonal_blocks[i] + @ B[i * diag_blocksize : (i + 1) * diag_blocksize] ) B[-arrow_blocksize:] -= ( From 39ca52be9c7e4fa256fa8600d7b856909a0410de Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:05:39 +0000 Subject: [PATCH 478/518] added missing minus --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 936f78ca..7150c85e 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -96,7 +96,7 @@ def _pobtas( ) ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( L_lower_diagonal_blocks[i] @ B[i * diag_blocksize : (i + 1) * diag_blocksize] ) From 6f5ed8196fc84940574d6fb354c369fd921280c2 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:07:07 +0000 Subject: [PATCH 479/518] added more trsm --- src/serinv/algs/pobtas.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 7150c85e..46075010 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -110,7 +110,7 @@ def _pobtas( # In the case of the partial solve, we do not solve the last block and # arrow tip block of the RHS. 
B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[n_diag_blocks - 1], B[ (n_diag_blocks - 1) @@ -131,14 +131,14 @@ def _pobtas( ) # Y_{ndb+1} = L_{ndb+1,ndb+1}^{-1} (B_{ndb+1} - \Sigma_{i=1}^{ndb} L_{ndb+1,i} Y_{i) - B[-arrow_blocksize:] = la.solve_triangular( + B[-arrow_blocksize:] = trsm( L_arrow_tip_block[:], B[-arrow_blocksize:], lower=True ) elif trans == "T" or trans == "C": # ----- Backward substitution ----- if not partial: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) - B[-arrow_blocksize:] = la.solve_triangular( + B[-arrow_blocksize:] = trsm( L_arrow_tip_block[:], B[-arrow_blocksize:], lower=True, @@ -147,7 +147,7 @@ def _pobtas( # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[-1], B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] - L_lower_arrow_blocks[-1].conj().T @ B[-arrow_blocksize:], @@ -158,7 +158,7 @@ def _pobtas( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize] - L_lower_diagonal_blocks[i].conj().T From e7df35a9237ecde431d149b0c4a4cf98b238e2ab Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:35:49 +0000 Subject: [PATCH 480/518] changed all solve trinagular to trsm --- src/serinv/algs/pobtas.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 46075010..c66670d6 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -189,7 +189,7 @@ def _pobtas_permuted( if trans == "N": # ----- Forward substitution ----- for i in range(1, n_diag_blocks - 
1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], lower=True, @@ -214,7 +214,7 @@ def _pobtas_permuted( elif trans == "T" or trans == "C": # ----- Backward substitution ----- for i in range(n_diag_blocks - 2, 0, -1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize] - L_lower_diagonal_blocks[i].conj().T @@ -350,7 +350,7 @@ def _pobtas_streaming( # Solve current B block compute_stream.wait_event(h2d_diagonal_events[i % 2]) - B_d[i % 2] = cu_la.solve_triangular( + B_d[i % 2] = trsm( L_diagonal_blocks_d[i % 2], B_d[i % 2], lower=True, @@ -436,7 +436,7 @@ def _pobtas_streaming( # Solve last B block compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) - B_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + B_d[(n_diag_blocks - 1) % 2] = trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2], lower=True, @@ -467,7 +467,7 @@ def _pobtas_streaming( L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2] @ B_d[(n_diag_blocks - 1) % 2] ) - B_arrow_tip_d = cu_la.solve_triangular( + B_arrow_tip_d = trsm( L_arrow_tip_block_d, B_arrow_tip_d, lower=True ) @@ -517,7 +517,7 @@ def _pobtas_streaming( with compute_stream: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) - B_arrow_tip_d = cu_la.solve_triangular( + B_arrow_tip_d = trsm( L_arrow_tip_block_d, B_arrow_tip_d, lower=True, @@ -525,7 +525,7 @@ def _pobtas_streaming( ) # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) - B_previous_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + B_previous_d[(n_diag_blocks - 1) % 2] = trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 
2] - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].conj().T @@ -594,7 +594,7 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - B_previous_d[i % 2] = cu_la.solve_triangular( + B_previous_d[i % 2] = trsm( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2].conj().T From 6b072bdc62315d5e88d6f7dbd712abe3566c20bd Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:38:42 +0000 Subject: [PATCH 481/518] added first gemm --- src/serinv/algs/pobtas.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index c66670d6..32f10512 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -96,9 +96,17 @@ def _pobtas( ) ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + #B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( + # L_lower_diagonal_blocks[i] + # @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + #) + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + trans='N', alpha=-1.0, beta=1.0 + ) ) B[-arrow_blocksize:] -= ( From e5da93fcf29a49055565f6e693701ec0c7fd67cd Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:39:27 +0000 Subject: [PATCH 482/518] fixed keyword --- src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 32f10512..adf92f44 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -105,7 +105,7 @@ def _pobtas( L_lower_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], - trans='N', 
alpha=-1.0, beta=1.0 + alpha=-1.0, beta=1.0 ) ) From d9539b17b120371bdceef623a1fb555921c11172 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:40:42 +0000 Subject: [PATCH 483/518] switched a and b --- src/serinv/algs/pobtas.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index adf92f44..1a805e8f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -102,8 +102,9 @@ def _pobtas( #) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( gemm( - L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + L_lower_diagonal_blocks[i], B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], alpha=-1.0, beta=1.0 ) From c90226b2e9f74670b7adaf1eab2e54447545cefd Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:41:37 +0000 Subject: [PATCH 484/518] reverted switch --- src/serinv/algs/pobtas.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 1a805e8f..adf92f44 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -102,9 +102,8 @@ def _pobtas( #) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( gemm( - - B[i * diag_blocksize : (i + 1) * diag_blocksize], L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], alpha=-1.0, beta=1.0 ) From de9114704a36996cdc19b35679abf702f107f1a6 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:43:43 +0000 Subject: [PATCH 485/518] fixed gemm validation --- src/serinv/block_primitive/gemm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 8924cb2b..6bc907e7 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -44,19 +44,19 @@ def matmul_gemm_host(a, b, 
alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov c1 = _asarray_validated(c, check_finite=check_finite) if not trans_a and not trans_b: - if a1.shape[0] != b1.shape[0]: + if a1.shape[1] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') elif trans_a and not trans_b: - if a1.shape[1] != b1.shape[0]: + if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') elif not trans_a and trans_b: - if a1.shape[0] != b1.shape[1]: + if a1.shape[1] != b1.shape[1]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') else: - if a1.shape[1] != b1.shape[1]: + if a1.shape[0] != b1.shape[1]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') if beta != 0 and c1 is None: From 5db201e0159cf25d6b13f815fbe9f5e19c71f81d Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:45:28 +0000 Subject: [PATCH 486/518] error message to show which shape check is wrong --- src/serinv/block_primitive/gemm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 6bc907e7..7e94b9a1 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -45,19 +45,19 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov if not trans_a and not trans_b: if a1.shape[1] != b1.shape[0]: - raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (1,0)') elif trans_a and not trans_b: if a1.shape[0] != b1.shape[0]: - raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (0,0)') elif not trans_a and trans_b: if a1.shape[1] != b1.shape[1]: - raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are 
incompatible') + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (1,1)') else: if a1.shape[0] != b1.shape[1]: - raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (0,1)') if beta != 0 and c1 is None: raise ValueError('expected C matrix') From 17490b28d7deb22fe0cdd3ea863794009cd77c25 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:46:41 +0000 Subject: [PATCH 487/518] swapped shapes --- src/serinv/block_primitive/gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 7e94b9a1..e5976e9c 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -48,7 +48,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (1,0)') elif trans_a and not trans_b: - if a1.shape[0] != b1.shape[0]: + if a1.shape[0] != b1.shape[1]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (0,0)') elif not trans_a and trans_b: @@ -56,7 +56,7 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (1,1)') else: - if a1.shape[0] != b1.shape[1]: + if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (0,1)') if beta != 0 and c1 is None: From f210506c23f575fb89400c0e27c42c84ceaa9b71 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:51:09 +0000 Subject: [PATCH 488/518] fixing shape check --- src/serinv/block_primitive/gemm.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index e5976e9c..34a8d50d 100644 --- 
a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -43,20 +43,27 @@ def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, ov else: c1 = _asarray_validated(c, check_finite=check_finite) - if not trans_a and not trans_b: + transa = True + transb = True + if trans_a == 'N': + transa = False + if trans_b == 'N': + transb = False + + if not transa and not transb: if a1.shape[1] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (1,0)') - elif trans_a and not trans_b: - if a1.shape[0] != b1.shape[1]: + elif transa and not transb: + if a1.shape[0] != b1.shape[0]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (0,0)') - elif not trans_a and trans_b: + elif not transa and transb: if a1.shape[1] != b1.shape[1]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (1,1)') else: - if a1.shape[0] != b1.shape[0]: + if a1.shape[0] != b1.shape[1]: raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (0,1)') if beta != 0 and c1 is None: From 19a2bfb2cbf7162c8733283acb39e860300ef3ca Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:52:24 +0000 Subject: [PATCH 489/518] next gemm --- src/serinv/algs/pobtas.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index adf92f44..a20ac7b3 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -110,8 +110,12 @@ def _pobtas( ) B[-arrow_blocksize:] -= ( - L_lower_arrow_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + gemm( + L_lower_arrow_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[-arrow_blocksize:], + alpha=-1.0, beta=1.0 + ) ) if not partial: From f2ca36ca7e3a0c10110d98387c3cc85ae510839a Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:53:15 +0000 Subject: [PATCH 490/518] removed minus --- 
src/serinv/algs/pobtas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index a20ac7b3..c0cdf6c8 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -109,7 +109,7 @@ def _pobtas( ) ) - B[-arrow_blocksize:] -= ( + B[-arrow_blocksize:] = ( gemm( L_lower_arrow_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], From 390022e821bcb9eb95dd6c4f970603273b0566da Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:54:39 +0000 Subject: [PATCH 491/518] another gemm --- src/serinv/algs/pobtas.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index c0cdf6c8..981a9629 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -133,19 +133,24 @@ def _pobtas( ) ) - B[-arrow_blocksize:] -= ( - L_lower_arrow_blocks[-1] - @ B[ - (n_diag_blocks - 1) - * diag_blocksize : n_diag_blocks - * diag_blocksize - ] + B[-arrow_blocksize:] = ( + gemm( + L_lower_arrow_blocks[-1], + B[ + (n_diag_blocks - 1) + * diag_blocksize : n_diag_blocks + * diag_blocksize + ], + B[-arrow_blocksize:], + alpha=-1.0, beta=1.0 + ) ) # Y_{ndb+1} = L_{ndb+1,ndb+1}^{-1} (B_{ndb+1} - \Sigma_{i=1}^{ndb} L_{ndb+1,i} Y_{i) B[-arrow_blocksize:] = trsm( L_arrow_tip_block[:], B[-arrow_blocksize:], lower=True ) + elif trans == "T" or trans == "C": # ----- Backward substitution ----- if not partial: From f12632fb5dc14160aaafe3ca1439b170b8041393 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 09:57:45 +0000 Subject: [PATCH 492/518] gemm in trsm --- src/serinv/algs/pobtas.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 981a9629..27ac5e40 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -150,7 +150,7 @@ def _pobtas( B[-arrow_blocksize:] = trsm( L_arrow_tip_block[:], 
B[-arrow_blocksize:], lower=True ) - + elif trans == "T" or trans == "C": # ----- Backward substitution ----- if not partial: @@ -166,8 +166,12 @@ def _pobtas( B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] = ( trsm( L_diagonal_blocks[-1], - B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] - - L_lower_arrow_blocks[-1].conj().T @ B[-arrow_blocksize:], + gemm( + L_lower_arrow_blocks[-1], + B[-arrow_blocksize:], + B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ), lower=True, trans="C", ) From 138b5995805b81da22f0300a8b7d54af13939fb7 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:02:12 +0000 Subject: [PATCH 493/518] attempt at disambiguating last solve in normal pobtas --- src/serinv/algs/pobtas.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 27ac5e40..dc2845b6 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -179,12 +179,27 @@ def _pobtas( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_arrow_blocks[i], + B[-arrow_blocksize:], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize] - - L_lower_diagonal_blocks[i].conj().T - @ B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] - - L_lower_arrow_blocks[i].conj().T @ B[-arrow_blocksize:], + B[i * diag_blocksize : (i + 1) * diag_blocksize], 
lower=True, trans="C", ) From a6b11e24517ec5d481e37936f92eb9b44cd613e0 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:03:21 +0000 Subject: [PATCH 494/518] cleaning up mess --- src/serinv/algs/pobtas.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index dc2845b6..5a729681 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -164,14 +164,18 @@ def _pobtas( # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] = ( - trsm( - L_diagonal_blocks[-1], - gemm( + gemm( L_lower_arrow_blocks[-1], B[-arrow_blocksize:], B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], trans_a='C', alpha=-1.0, beta=1.0 - ), + ) + ) + + B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] = ( + trsm( + L_diagonal_blocks[-1], + B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], lower=True, trans="C", ) From 7f1418e65e6924d46f7a643e99c9504576d0f62f Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:14:28 +0000 Subject: [PATCH 495/518] added some gemms --- src/serinv/algs/pobtas.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 5a729681..f289d733 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -236,20 +236,33 @@ def _pobtas_permuted( ) # Update the next RHS block - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) # Update the first RHS block (permutation-linked) - B[:diag_blocksize] -= ( - buffer[i] @ B[i * 
diag_blocksize : (i + 1) * diag_blocksize] + B[:diag_blocksize] = ( + gemm( + buffer[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[:diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) # Update the tip RHS block - B[-arrow_blocksize:] -= ( - L_lower_arrow_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[-arrow_blocksize:] = ( + gemm( + L_lower_arrow_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[-arrow_blocksize:], + alpha=-1.0, beta=1.0 + ) ) elif trans == "T" or trans == "C": # ----- Backward substitution ----- From bdc9f634865bc080337127cd37661ae8c7e2bba3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:18:53 +0000 Subject: [PATCH 496/518] disambiguate big solve block --- src/serinv/algs/pobtas.py | 63 +++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index f289d733..115e23ec 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -147,8 +147,12 @@ def _pobtas( ) # Y_{ndb+1} = L_{ndb+1,ndb+1}^{-1} (B_{ndb+1} - \Sigma_{i=1}^{ndb} L_{ndb+1,i} Y_{i) - B[-arrow_blocksize:] = trsm( - L_arrow_tip_block[:], B[-arrow_blocksize:], lower=True + B[-arrow_blocksize:] = ( + trsm( + L_arrow_tip_block[:], + B[-arrow_blocksize:], + lower=True + ) ) elif trans == "T" or trans == "C": @@ -201,11 +205,13 @@ def _pobtas( ) ) - B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, - trans="C", + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + trans="C", + ) ) else: raise ValueError(f"Invalid transpose argument: {trans}.") @@ -267,15 +273,40 @@ def _pobtas_permuted( elif trans == "T" or trans == "C": # ----- Backward substitution ----- for i in range(n_diag_blocks - 2, 0, -1): - B[i * diag_blocksize : (i + 
1) * diag_blocksize] = trsm( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize] - - L_lower_diagonal_blocks[i].conj().T - @ B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] - - L_lower_arrow_blocks[i].conj().T @ B[-arrow_blocksize:] - - buffer[i].conj().T @ B[:diag_blocksize], - lower=True, - trans="C", + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_arrow_blocks[i], + B[-arrow_blocksize:], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + buffer[i], + B[:diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + trans="C", + ) ) else: raise ValueError(f"Invalid transpose argument: {trans}.") From f6c29c4003215c5706ca1792622bfb5932ff6f2b Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:21:05 +0000 Subject: [PATCH 497/518] reformattign and adding first gemm to streaming --- src/serinv/algs/pobtas.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 115e23ec..075caa7f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -434,10 +434,12 @@ def _pobtas_streaming( # Solve current B block compute_stream.wait_event(h2d_diagonal_events[i % 2]) - B_d[i % 2] = trsm( - L_diagonal_blocks_d[i % 2], - B_d[i % 2], - lower=True, + B_d[i % 2] = ( + trsm( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + ) ) 
compute_current_B_events[i % 2].record(stream=compute_stream) @@ -468,7 +470,14 @@ def _pobtas_streaming( # Update next B block compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) - B_d[(i + 1) % 2] -= L_lower_diagonal_blocks_d[i % 2] @ B_d[i % 2] + B_d[(i + 1) % 2] = ( + gemm( + L_lower_diagonal_blocks_d[i % 2], + B_d[i % 2], + B_d[(i + 1) % 2], + alpha=-1.0, beta=1.0 + ) + ) compute_next_B_events[i % 2].record(stream=compute_stream) From 1e7dae8bea31f7a05142b5be807d3f0c07bff783 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:24:23 +0000 Subject: [PATCH 498/518] reformatting and adding more gemm --- src/serinv/algs/pobtas.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 075caa7f..e0c69756 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -494,7 +494,14 @@ def _pobtas_streaming( # Update arrow tip compute_stream.wait_event(h2d_arrow_events[i % 2]) - B_arrow_tip_d -= L_lower_arrow_blocks_d[i % 2] @ B_d[i % 2] + B_arrow_tip_d = ( + gemm( + L_lower_arrow_blocks_d[i % 2], + B_d[i % 2], + B_arrow_tip_d, + alpha=-1.0, beta=1.0 + ) + ) compute_arrow_B_events[i % 2].record(stream=compute_stream) @@ -529,10 +536,12 @@ def _pobtas_streaming( # Solve last B block compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) - B_d[(n_diag_blocks - 1) % 2] = trsm( - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], - B_d[(n_diag_blocks - 1) % 2], - lower=True, + B_d[(n_diag_blocks - 1) % 2] = ( + trsm( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + ) ) compute_partial_events[0].record(stream=compute_stream) @@ -556,12 +565,20 @@ def _pobtas_streaming( # Solve arrow tip compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) - B_arrow_tip_d -= ( - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2] - @ B_d[(n_diag_blocks - 1) % 2] + B_arrow_tip_d = ( + 
gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + B_arrow_tip_d, + alpha=-1.0, beta=1.0 + ) ) - B_arrow_tip_d = trsm( - L_arrow_tip_block_d, B_arrow_tip_d, lower=True + B_arrow_tip_d = ( + trsm( + L_arrow_tip_block_d, + B_arrow_tip_d, + lower=True + ) ) compute_partial_events[1].record(stream=compute_stream) From 67d93bd18134ecac7a01f93ee684b19e0686eb4c Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:27:48 +0000 Subject: [PATCH 499/518] added even more gemm and reformatting --- src/serinv/algs/pobtas.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index e0c69756..4193be6b 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -627,21 +627,32 @@ def _pobtas_streaming( with compute_stream: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) - B_arrow_tip_d = trsm( - L_arrow_tip_block_d, - B_arrow_tip_d, - lower=True, - trans="C", + B_arrow_tip_d = ( + trsm( + L_arrow_tip_block_d, + B_arrow_tip_d, + lower=True, + trans="C", + ) ) # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) - B_previous_d[(n_diag_blocks - 1) % 2] = trsm( - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], - B_d[(n_diag_blocks - 1) % 2] - - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].conj().T - @ B_arrow_tip_d, - lower=True, - trans="C", + B_previous_d[(n_diag_blocks - 1) % 2] = ( + gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2], + B_arrow_tip_d, + B_d[(n_diag_blocks - 1) % 2], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B_previous_d[(n_diag_blocks - 1) % 2] = ( + trsm( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + trans="C", + ) ) compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) From 7ce90bf5693ec3d92b7b0eb42c00733b20c0bba8 Mon Sep 17 00:00:00 2001 
From: 03szust Date: Wed, 18 Jun 2025 10:32:30 +0000 Subject: [PATCH 500/518] disambiguate and add gemm --- src/serinv/algs/pobtas.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index 4193be6b..7a41c339 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -715,14 +715,31 @@ def _pobtas_streaming( compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - B_previous_d[i % 2] = trsm( - L_diagonal_blocks_d[i % 2], - B_d[i % 2] - - L_lower_diagonal_blocks_d[i % 2].conj().T - @ B_previous_d[(i - 1) % 2] - - L_lower_arrow_blocks_d[i % 2].conj().T @ B_arrow_tip_d, - lower=True, - trans="C", + B_d[i % 2] = ( + gemm( + L_lower_diagonal_blocks_d[i % 2], + B_previous_d[(i - 1) % 2], + B_d[i % 2], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B_d[i % 2] = ( + gemm( + L_lower_arrow_blocks_d[i % 2], + B_arrow_tip_d, + B_d[i % 2], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B_previous_d[i % 2] = ( + trsm( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + trans="C", + ) ) compute_B_events[i % 2].record(compute_stream) From b331e0bbf95032e572105c6f084b25f2c37b34b3 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:39:02 +0000 Subject: [PATCH 501/518] added trsm to pobtasi --- src/serinv/algs/pobtasi.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/serinv/algs/pobtasi.py b/src/serinv/algs/pobtasi.py index 7e572d19..af63e016 100644 --- a/src/serinv/algs/pobtasi.py +++ b/src/serinv/algs/pobtasi.py @@ -6,6 +6,8 @@ _get_module_from_str, ) +from serinv.block_primitive import trsm, gemm, syherk + def pobtasi( L_diagonal_blocks: ArrayLike, @@ -112,7 +114,7 @@ def _pobtasi( Identity = xp.eye(L_diagonal_blocks.shape[1]) if invert_last_block: - L_last_blk_inv = la.solve_triangular( + L_last_blk_inv = trsm( L_arrow_tip_block[:, :],
xp.eye(L_arrow_tip_block.shape[0]), lower=True ) @@ -121,7 +123,7 @@ def _pobtasi( # Backward block-selected inversion L_lower_arrow_blocks_i[:, :] = L_lower_arrow_blocks[-1, :, :] - L_blk_inv = la.solve_triangular( + L_blk_inv = trsm( L_diagonal_blocks[-1, :, :], Identity, lower=True, @@ -142,7 +144,7 @@ def _pobtasi( L_lower_diagonal_blocks_i[:, :] = L_lower_diagonal_blocks[i, :, :] L_lower_arrow_blocks_i[:, :] = L_lower_arrow_blocks[i, :, :] - L_blk_inv = la.solve_triangular( + L_blk_inv = trsm( L_diagonal_blocks[i, :, :], Identity, lower=True, @@ -201,7 +203,7 @@ def _pobtasi_permuted( L_lower_arrow_blocks_temp[:, :] = L_lower_arrow_blocks[i, :, :] buffer_temp[:, :] = buffer[i, :, :] - L_inv_temp[:, :] = la.solve_triangular( + L_inv_temp[:, :] = trsm( L_diagonal_blocks[i, :, :], xp.eye(diag_blocksize), lower=True, @@ -321,7 +323,7 @@ def _pobtasi_streaming( with compute_stream: if invert_last_block: # X_{ndb+1, ndb+1} = L_{ndb+1, ndb}^{-T} L_{ndb+1, ndb}^{-1} - L_last_blk_inv_d = cu_la.solve_triangular( + L_last_blk_inv_d = cu_trsm( L_arrow_tip_block_d[:, :], cp.eye(L_arrow_tip_block.shape[0]), lower=True, @@ -356,7 +358,7 @@ def _pobtasi_streaming( compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) if invert_last_block: # X_{ndb+1, ndb} = -X_{ndb+1, ndb+1} L_{ndb+1, ndb} L_{ndb, ndb}^{-1} - L_blk_inv_d = cu_la.solve_triangular( + L_blk_inv_d = cu_trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], Identity, lower=True, @@ -434,7 +436,7 @@ def _pobtasi_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_blk_inv_d = cu_la.solve_triangular( + L_blk_inv_d = cu_trsm( L_diagonal_blocks_d[i % 2, :, :], Identity, lower=True, @@ -632,7 +634,7 @@ def _pobtasi_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_inv_temp_d[:, :] = cu_la.solve_triangular( + L_inv_temp_d[:, :] = cu_trsm( L_diagonal_blocks_d[i % 2, :, :], cp.eye(diag_blocksize), lower=True, From 
5360379409d83b0793255d3352e1ef5f30df7b88 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:45:12 +0000 Subject: [PATCH 502/518] fixed pobtasi trsm --- src/serinv/algs/pobtasi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/serinv/algs/pobtasi.py b/src/serinv/algs/pobtasi.py index af63e016..410dfb5f 100644 --- a/src/serinv/algs/pobtasi.py +++ b/src/serinv/algs/pobtasi.py @@ -323,7 +323,7 @@ def _pobtasi_streaming( with compute_stream: if invert_last_block: # X_{ndb+1, ndb+1} = L_{ndb+1, ndb}^{-T} L_{ndb+1, ndb}^{-1} - L_last_blk_inv_d = cu_trsm( + L_last_blk_inv_d = trsm( L_arrow_tip_block_d[:, :], cp.eye(L_arrow_tip_block.shape[0]), lower=True, @@ -358,7 +358,7 @@ def _pobtasi_streaming( compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) if invert_last_block: # X_{ndb+1, ndb} = -X_{ndb+1, ndb+1} L_{ndb+1, ndb} L_{ndb, ndb}^{-1} - L_blk_inv_d = cu_trsm( + L_blk_inv_d = trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], Identity, lower=True, @@ -436,7 +436,7 @@ def _pobtasi_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_blk_inv_d = cu_trsm( + L_blk_inv_d = trsm( L_diagonal_blocks_d[i % 2, :, :], Identity, lower=True, @@ -634,7 +634,7 @@ def _pobtasi_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_inv_temp_d[:, :] = cu_trsm( + L_inv_temp_d[:, :] = trsm( L_diagonal_blocks_d[i % 2, :, :], cp.eye(diag_blocksize), lower=True, From 74299071b97be2d8b1169d570b485500f72f9a72 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:51:49 +0000 Subject: [PATCH 503/518] added print and error to see how multiplication works --- src/serinv/algs/pobtasi.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/serinv/algs/pobtasi.py b/src/serinv/algs/pobtasi.py index 410dfb5f..d07c851c 100644 --- a/src/serinv/algs/pobtasi.py +++ b/src/serinv/algs/pobtasi.py @@ -152,10 +152,13 @@ def _pobtasi( # --- 
Off-diagonal block part --- # X_{i+1, i} = (-X_{i+1, i+1} L_{i+1, i} - X_{ndb+1, i+1}^{T} L_{ndb+1, i}) L_{i, i}^{-1} + print(X_diagonal_blocks[i + 1, :, :]) X_lower_diagonal_blocks[i, :, :] = ( -X_diagonal_blocks[i + 1, :, :] @ L_lower_diagonal_blocks_i[:, :] - X_arrow_bottom_blocks[i + 1, :, :].conj().T @ L_lower_arrow_blocks_i[:, :] ) @ L_blk_inv + print(X_diagonal_blocks[i + 1, :, :]) + raise ValueError("TEST") # --- Arrowhead part --- # X_{ndb+1, i} = (- X_{ndb+1, i+1} L_{i+1, i} - X_{ndb+1, ndb+1} L_{ndb+1, i}) L_{i, i}^{-1} From a927d9d3dfb675eaff01db21bfa19a9b57d5e6d5 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 10:53:13 +0000 Subject: [PATCH 504/518] removed tests --- src/serinv/algs/pobtasi.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/serinv/algs/pobtasi.py b/src/serinv/algs/pobtasi.py index d07c851c..410dfb5f 100644 --- a/src/serinv/algs/pobtasi.py +++ b/src/serinv/algs/pobtasi.py @@ -152,13 +152,10 @@ def _pobtasi( # --- Off-diagonal block part --- # X_{i+1, i} = (-X_{i+1, i+1} L_{i+1, i} - X_{ndb+1, i+1}^{T} L_{ndb+1, i}) L_{i, i}^{-1} - print(X_diagonal_blocks[i + 1, :, :]) X_lower_diagonal_blocks[i, :, :] = ( -X_diagonal_blocks[i + 1, :, :] @ L_lower_diagonal_blocks_i[:, :] - X_arrow_bottom_blocks[i + 1, :, :].conj().T @ L_lower_arrow_blocks_i[:, :] ) @ L_blk_inv - print(X_diagonal_blocks[i + 1, :, :]) - raise ValueError("TEST") # --- Arrowhead part --- # X_{ndb+1, i} = (- X_{ndb+1, i+1} L_{i+1, i} - X_{ndb+1, ndb+1} L_{ndb+1, i}) L_{i, i}^{-1} From e881574573982b4fb20fe1f09d2b2dd977c8946d Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 11:02:41 +0000 Subject: [PATCH 505/518] check if gemm can be applied in pobtasi --- src/serinv/algs/pobtasi.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/serinv/algs/pobtasi.py b/src/serinv/algs/pobtasi.py index 410dfb5f..c7313054 100644 --- a/src/serinv/algs/pobtasi.py +++ b/src/serinv/algs/pobtasi.py @@ -131,7 +131,11 @@ def _pobtasi( # 
X_{ndb+1, ndb} = -X_{ndb+1, ndb+1} L_{ndb+1, ndb} L_{ndb, ndb}^{-1} X_arrow_bottom_blocks[-1, :, :] = ( - -X_arrow_tip_block[:, :] @ L_lower_arrow_blocks_i[:, :] @ L_blk_inv + gemm( + X_arrow_tip_block[:, :], + L_lower_arrow_blocks_i[:, :], + alpha=-1.0 + ) @ L_blk_inv ) # X_{ndb, ndb} = (L_{ndb, ndb}^{-T} - X_{ndb+1, ndb}^{T} L_{ndb+1, ndb}) L_{ndb, ndb}^{-1} From 5d7800e9eec005afd258c389267dcc214cc00e9b Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 11:11:58 +0000 Subject: [PATCH 506/518] removed gemm from pobtasi --- src/serinv/algs/pobtasi.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/serinv/algs/pobtasi.py b/src/serinv/algs/pobtasi.py index c7313054..410dfb5f 100644 --- a/src/serinv/algs/pobtasi.py +++ b/src/serinv/algs/pobtasi.py @@ -131,11 +131,7 @@ def _pobtasi( # X_{ndb+1, ndb} = -X_{ndb+1, ndb+1} L_{ndb+1, ndb} L_{ndb, ndb}^{-1} X_arrow_bottom_blocks[-1, :, :] = ( - gemm( - X_arrow_tip_block[:, :], - L_lower_arrow_blocks_i[:, :], - alpha=-1.0 - ) @ L_blk_inv + -X_arrow_tip_block[:, :] @ L_lower_arrow_blocks_i[:, :] @ L_blk_inv ) # X_{ndb, ndb} = (L_{ndb, ndb}^{-T} - X_{ndb+1, ndb}^{T} L_{ndb+1, ndb}) L_{ndb, ndb}^{-1} From 89b0e070b4407a338c945991e53ca77cff73e94f Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 11:18:33 +0000 Subject: [PATCH 507/518] removed print at end of pobtas test --- tests/tests_algs/regular/tests_bta/test_pobtas.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index ffc290c2..61b81303 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -116,7 +116,5 @@ def test_pobtas( trans="C", device_streaming=True if array_type == "streaming" else False, ) - print("===") - print(X_ref) assert xp.allclose(B, X_ref) From 885a8d6b1d8a06d3c47acc8ca7845bea277d8576 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 
Jun 2025 11:19:19 +0000 Subject: [PATCH 508/518] removed all leftover prints --- src/serinv/algs/pobtaf.py | 1 - tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 674ff796..2260e970 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -671,7 +671,6 @@ def _pobtaf_permuted_streaming( A_diagonal_blocks_d[i % 2, :, :], lower=True ) cp_diagonal_events[i % 2].record(stream=compute_stream) - print(L_diagonal_blocks_d[i % 2, :, :]) d2h_stream.wait_event(cp_diagonal_events[i % 2]) L_diagonal_blocks_d[i % 2, :, :].get( diff --git a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py index e811d360..e4b9fc40 100644 --- a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py +++ b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py @@ -137,8 +137,6 @@ def test_pobtasi_permuted( _A_arrow_tip_block, ) - print(X_diagonal_blocks_ref) - print(_A_diagonal_blocks) # Verify that the reduced system is already correct assert xp.allclose(X_arrow_tip_block_ref, _A_arrow_tip_block) assert xp.allclose(X_diagonal_blocks_ref[0], _A_diagonal_blocks[0]) From 293922b160cf6836ff871ba501118c4092655a56 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 12:49:36 +0000 Subject: [PATCH 509/518] pobtf perf improvements --- src/serinv/algs/pobtaf.py | 1 - src/serinv/algs/pobtf.py | 53 +++++++++++++++++++++++---------------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 2260e970..4993a8b6 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -123,7 +123,6 @@ def _pobtaf( A_lower_diagonal_blocks[i, :, :], trans='C',lower=True, side=1 ) - ) diff --git a/src/serinv/algs/pobtf.py b/src/serinv/algs/pobtf.py index ed70433b..1e6f4fb8 100644 --- a/src/serinv/algs/pobtf.py +++ 
b/src/serinv/algs/pobtf.py @@ -7,6 +7,7 @@ _get_cholesky, ) +from serinv.block_primitive import trsm, gemm, syherk def pobtf( A_diagonal_blocks: ArrayLike, @@ -100,21 +101,21 @@ def _pobtf( # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + syherk( + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) if factorize_last_block: @@ -145,18 +146,16 @@ def _pobtf_permuted( # Compute lower factors # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} buffer[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, lower=True, @@ -168,20 +167,30 @@ def _pobtf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + syherk( + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - 
A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + syherk( + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) ) @@ -276,7 +285,7 @@ def _pobtf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -436,7 +445,7 @@ def _pobtf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -456,7 +465,7 @@ def _pobtf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, From a27dca95c47223640082db80a46e996af1f15af4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 12:54:40 +0000 Subject: [PATCH 510/518] improved pobtf --- src/serinv/algs/pobtf.py | 46 +++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/src/serinv/algs/pobtf.py b/src/serinv/algs/pobtf.py index 1e6f4fb8..becc3009 100644 --- a/src/serinv/algs/pobtf.py +++ b/src/serinv/algs/pobtf.py @@ -287,11 +287,9 @@ def _pobtf_streaming( L_lower_diagonal_blocks_d[i % 2, :, :] = ( trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) -
.conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -313,9 +311,11 @@ def _pobtf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) @@ -447,11 +447,9 @@ def _pobtf_permuted_streaming( L_lower_diagonal_blocks_d[i % 2, :, :] = ( trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -496,24 +494,34 @@ def _pobtf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + # gemm instead of syherk because this somehow kept failing tests in a very weird way + # probably because both sides of the diagonal matrix are used somwhere in a relevant way + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - -L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + trans_b='C', alpha=-1.0 
+ ) ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - A_diagonal_top_block_d[:, :] - - L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + syherk( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_diagonal_top_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=False + ) ) # --- Device 2 Host transfers --- From 4a2c175d7fd7dc4745d694bf576f37e8ed457f06 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 12:58:12 +0000 Subject: [PATCH 511/518] pobts normal improved --- src/serinv/algs/pobtasi.py | 2 +- src/serinv/algs/pobts.py | 52 ++++++++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/serinv/algs/pobtasi.py b/src/serinv/algs/pobtasi.py index 410dfb5f..4e6c2ab2 100644 --- a/src/serinv/algs/pobtasi.py +++ b/src/serinv/algs/pobtasi.py @@ -6,7 +6,7 @@ _get_module_from_str, ) -from serinv.block_primitive import trsm, gemm, syherk +from serinv.block_primitive import trsm def pobtasi( diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 34c8e74e..ba4d3100 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -7,6 +7,7 @@ _get_module_from_str, ) +from serinv.block_primitive import trsm, gemm def pobts( L_diagonal_blocks: ArrayLike, @@ -74,20 +75,24 @@ def _pobts( # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): # Y_{i} = L_{i,i}^{-1} (B_{i} - L_{i,i-1} Y_{i-1}) - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], lower=True, ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 
1) * diag_blocksize] + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) if not partial: B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[n_diag_blocks - 1], B[ (n_diag_blocks - 1) @@ -101,7 +106,7 @@ def _pobts( # ----- Backward substitution ----- if not partial: # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) - B[-diag_blocksize:] = la.solve_triangular( + B[-diag_blocksize:] = trsm( L_diagonal_blocks[-1], B[-diag_blocksize:], lower=True, @@ -110,13 +115,22 @@ def _pobts( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize] - - L_lower_diagonal_blocks[i].conj().T - @ B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], - lower=True, - trans="C", + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + trans="C", + ) ) else: raise ValueError(f"Invalid transpose argument: {trans}.") @@ -137,7 +151,7 @@ def _pobts_permuted( if trans == "N": # ----- Forward substitution ----- for i in range(1, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], lower=True, @@ 
-156,7 +170,7 @@ def _pobts_permuted( elif trans == "T" or trans == "C": # ----- Backward substitution ----- for i in range(n_diag_blocks - 2, 0, -1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( + B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize] - L_lower_diagonal_blocks[i].conj().T @@ -234,7 +248,7 @@ def _pobts_streaming( # Solve first B block compute_stream.wait_event(h2d_events[1]) - B_previous_d[0] = cu_la.solve_triangular( + B_previous_d[0] = trsm( L_diagonal_blocks_d[0], B_d[0], lower=True, @@ -266,7 +280,7 @@ def _pobts_streaming( compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) - B_previous_d[i % 2] = cu_la.solve_triangular( + B_previous_d[i % 2] = trsm( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2] @ B_previous_d[(i + 1) % 2], @@ -322,7 +336,7 @@ def _pobts_streaming( # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) - B_previous_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + B_previous_d[(n_diag_blocks - 1) % 2] = trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], B_d[(n_diag_blocks - 1) % 2], lower=True, @@ -355,7 +369,7 @@ def _pobts_streaming( compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - B_previous_d[i % 2] = cu_la.solve_triangular( + B_previous_d[i % 2] = trsm( L_diagonal_blocks_d[i % 2], B_d[i % 2] - L_lower_diagonal_blocks_d[i % 2].conj().T From 80ccc4c40f96bfba254f78cf2b2ac4d5d32f1744 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 13:01:30 +0000 Subject: [PATCH 512/518] pobts permuted improved --- src/serinv/algs/pobtas.py | 10 +++-- src/serinv/algs/pobts.py | 85 +++++++++++++++++++++++++++------------ 2 files changed, 65 insertions(+), 30 deletions(-) diff --git a/src/serinv/algs/pobtas.py 
b/src/serinv/algs/pobtas.py index 7a41c339..4d0f4c8f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -235,10 +235,12 @@ def _pobtas_permuted( if trans == "N": # ----- Forward substitution ----- for i in range(1, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + ) ) # Update the next RHS block diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index ba4d3100..64f62c03 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -75,10 +75,12 @@ def _pobts( # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): # Y_{i} = L_{i,i}^{-1} (B_{i} - L_{i,i-1} Y_{i-1}) - B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True + ) ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( @@ -106,11 +108,13 @@ def _pobts( # ----- Backward substitution ----- if not partial: # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) - B[-diag_blocksize:] = trsm( - L_diagonal_blocks[-1], - B[-diag_blocksize:], - lower=True, - trans="C", + B[-diag_blocksize:] = ( + trsm( + L_diagonal_blocks[-1], + B[-diag_blocksize:], + lower=True, + trans="C", + ) ) for i in range(n_diag_blocks - 2, -1, -1): @@ -151,33 +155,62 @@ def _pobts_permuted( if trans == "N": # ----- Forward substitution ----- for i in range(1, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + B[i * diag_blocksize : 
(i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + ) ) # Update the next RHS block - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) # Update the first RHS block (permutation-linked) - B[:diag_blocksize] -= ( - buffer[i] @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[:diag_blocksize] = ( + gemm( + buffer[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[:diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) + elif trans == "T" or trans == "C": # ----- Backward substitution ----- for i in range(n_diag_blocks - 2, 0, -1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = trsm( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize] - - L_lower_diagonal_blocks[i].conj().T - @ B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] - - buffer[i].conj().T @ B[:diag_blocksize], - lower=True, - trans="C", + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + buffer[i], + B[:diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + trans="C", + ) ) else: raise ValueError(f"Invalid transpose argument: {trans}.") From a005d1c4b92c1462013ba43fdfc854219b417271 Mon Sep 17 00:00:00 
2001 From: 03szust Date: Wed, 18 Jun 2025 13:09:08 +0000 Subject: [PATCH 513/518] improved pobts --- src/serinv/algs/pobts.py | 65 +++++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 64f62c03..684034d7 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -281,10 +281,12 @@ def _pobts_streaming( # Solve first B block compute_stream.wait_event(h2d_events[1]) - B_previous_d[0] = trsm( - L_diagonal_blocks_d[0], - B_d[0], - lower=True, + B_previous_d[0] = ( + trsm( + L_diagonal_blocks_d[0], + B_d[0], + lower=True, + ) ) compute_B_events[0].record(stream=compute_stream) @@ -313,11 +315,21 @@ def _pobts_streaming( compute_stream.wait_event(h2d_events[(i + 1) % 2]) compute_stream.wait_event(d2h_events[(i + 1) % 2]) - B_previous_d[i % 2] = trsm( - L_diagonal_blocks_d[i % 2], - B_d[i % 2] - - L_lower_diagonal_blocks_d[i % 2] @ B_previous_d[(i + 1) % 2], - lower=True, + B_d[i % 2] = ( + gemm( + L_lower_diagonal_blocks_d[i % 2], + B_previous_d[(i + 1) % 2], + B_d[i % 2], + alpha=-1.0, beta=1.0 + ) + ) + + B_previous_d[i % 2] = ( + trsm( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + ) ) compute_B_events[i % 2].record(compute_stream) @@ -369,11 +381,13 @@ def _pobts_streaming( # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) - B_previous_d[(n_diag_blocks - 1) % 2] = trsm( - L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], - B_d[(n_diag_blocks - 1) % 2], - lower=True, - trans="C", + B_previous_d[(n_diag_blocks - 1) % 2] = ( + trsm( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + trans="C", + ) ) compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) @@ -402,13 +416,22 @@ def _pobts_streaming( compute_stream.wait_event(h2d_events[(i - 1) % 2]) compute_stream.wait_event(d2h_events[(i - 1) % 2]) - 
B_previous_d[i % 2] = trsm( - L_diagonal_blocks_d[i % 2], - B_d[i % 2] - - L_lower_diagonal_blocks_d[i % 2].conj().T - @ B_previous_d[(i - 1) % 2], - lower=True, - trans="C", + B_d[i % 2] = ( + gemm( + L_lower_diagonal_blocks_d[i % 2], + B_previous_d[(i - 1) % 2], + B_d[i % 2], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B_previous_d[i % 2] = ( + trsm( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + trans="C", + ) ) compute_B_events[i % 2].record(compute_stream) From 9936df6f732a8668a716ac95f048a844a61194a4 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 13:10:18 +0000 Subject: [PATCH 514/518] improved pobtsi --- src/serinv/algs/pobtsi.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serinv/algs/pobtsi.py b/src/serinv/algs/pobtsi.py index 0d2d1e6a..0d981716 100644 --- a/src/serinv/algs/pobtsi.py +++ b/src/serinv/algs/pobtsi.py @@ -92,7 +92,7 @@ def _pobtsi( Identity = xp.eye(L_diagonal_blocks.shape[1]) if invert_last_block: - L_blk_inv = la.solve_triangular( + L_blk_inv = trsm( L_diagonal_blocks[-1, :, :], Identity, lower=True, @@ -104,7 +104,7 @@ def _pobtsi( for i in range(n_diag_blocks - 2, -1, -1): L_lower_diagonal_blocks_i[:, :] = L_lower_diagonal_blocks[i, :, :] - L_blk_inv = la.solve_triangular( + L_blk_inv = trsm( L_diagonal_blocks[i, :, :], Identity, lower=True, @@ -148,7 +148,7 @@ def _pobtsi_permuted( L_lower_diagonal_blocks_temp[:, :] = L_lower_diagonal_blocks[i, :, :] buffer_temp[:, :] = buffer[i, :, :] - L_inv_temp[:, :] = la.solve_triangular( + L_inv_temp[:, :] = trsm( L_diagonal_blocks[i, :, :], xp.eye(diag_blocksize), lower=True, @@ -245,7 +245,7 @@ def _pobtsi_streaming( compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) if invert_last_block: # X_{ndb+1, ndb} = -X_{ndb+1, ndb+1} L_{ndb+1, ndb} L_{ndb, ndb}^{-1} - L_blk_inv_d = cu_la.solve_triangular( + L_blk_inv_d = trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], Identity, lower=True, @@ -289,7 +289,7 
@@ def _pobtsi_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_blk_inv_d = cu_la.solve_triangular( + L_blk_inv_d = trsm( L_diagonal_blocks_d[i % 2, :, :], Identity, lower=True, @@ -435,7 +435,7 @@ def _pobtsi_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_inv_temp_d[:, :] = cu_la.solve_triangular( + L_inv_temp_d[:, :] = trsm( L_diagonal_blocks_d[i % 2, :, :], cp.eye(diag_blocksize), lower=True, From 9b5e95589bd3de53628efebd4ce9b34821ecf31d Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 13:10:55 +0000 Subject: [PATCH 515/518] added missing import --- src/serinv/algs/pobtsi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/serinv/algs/pobtsi.py b/src/serinv/algs/pobtsi.py index 0d981716..ec7294f3 100644 --- a/src/serinv/algs/pobtsi.py +++ b/src/serinv/algs/pobtsi.py @@ -6,6 +6,7 @@ _get_module_from_str, ) +from serinv.block_primitive import trsm def pobtsi( L_diagonal_blocks: ArrayLike, From 9a17bf10ed7acebab33d5129afff8a70e58cb1bf Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 16:42:33 +0000 Subject: [PATCH 516/518] added copyright --- src/serinv/block_primitive/__init__.py | 2 ++ src/serinv/block_primitive/gemm.py | 4 ++++ src/serinv/block_primitive/syherk.py | 4 ++++ src/serinv/block_primitive/trsm.py | 4 ++++ 4 files changed, 14 insertions(+) diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py index 89b14a73..687c9c8f 100644 --- a/src/serinv/block_primitive/__init__.py +++ b/src/serinv/block_primitive/__init__.py @@ -1,3 +1,5 @@ +# Copyright 2023-2025 ETH Zurich. All rights reserved. 
+ from serinv.block_primitive.gemm import gemm from serinv.block_primitive.trsm import trsm from serinv.block_primitive.syherk import syherk diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 34a8d50d..51579c6c 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -1,3 +1,7 @@ +# Copyright 2023-2025 ETH Zurich. All rights reserved. +# Forked and modified from cupy.cublas.gemm: https://github.com/cupy/cupy/blob/3a2c950d64ee707096bc7ca1bf0b953a08206384/cupy/cublas.py#L689 +# and scipy.linal.solve_triangular: https://github.com/scipy/scipy/blob/v1.15.3/scipy/linalg/_basic.py#L411 + from serinv import _get_module_from_array import numpy as np diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index d709f4b4..3755a83a 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -1,3 +1,7 @@ +# Copyright 2023-2025 ETH Zurich. All rights reserved. +# Forked and modified from cupy.cublas.syrk: https://github.com/cupy/cupy/blob/3a2c950d64ee707096bc7ca1bf0b953a08206384/cupy/cublas.py#L930 +# and scipy.linal.solve_triangular: https://github.com/scipy/scipy/blob/v1.15.3/scipy/linalg/_basic.py#L411 + from serinv import _get_module_from_array from serinv.block_primitive import gemm diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py index 0c236350..2e660e39 100644 --- a/src/serinv/block_primitive/trsm.py +++ b/src/serinv/block_primitive/trsm.py @@ -1,3 +1,7 @@ +# Copyright 2023-2025 ETH Zurich. All rights reserved. 
+# Forked and modified from cupyx.linalg.solve_triangular: https://github.com/cupy/cupy/blob/3a2c950d64ee707096bc7ca1bf0b953a08206384/cupyx/scipy/linalg/_solve_triangular.py#L12 +# and scipy.linal.solve_triangular: https://github.com/scipy/scipy/blob/v1.15.3/scipy/linalg/_basic.py#L411 + import numpy as np from serinv import _get_module_from_array From 2406d59674f3c863476939bcad5831a767fcd148 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 16:50:24 +0000 Subject: [PATCH 517/518] removed unneeded matmul --- src/serinv/block_primitive/gemm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py index 51579c6c..c474c72a 100644 --- a/src/serinv/block_primitive/gemm.py +++ b/src/serinv/block_primitive/gemm.py @@ -5,7 +5,6 @@ from serinv import _get_module_from_array import numpy as np -from numpy.linalg import matmul from scipy.linalg.blas import get_blas_funcs from scipy.linalg._misc import _datacopied From 7d50b35446a6c2dfd6913f41c29d08549f678a63 Mon Sep 17 00:00:00 2001 From: 03szust Date: Wed, 18 Jun 2025 16:51:20 +0000 Subject: [PATCH 518/518] removed another matmul --- src/serinv/block_primitive/syherk.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py index 3755a83a..75367b3c 100644 --- a/src/serinv/block_primitive/syherk.py +++ b/src/serinv/block_primitive/syherk.py @@ -7,7 +7,6 @@ from serinv.block_primitive import gemm import numpy as np -from numpy.linalg import matmul from scipy.linalg.blas import get_blas_funcs from scipy.linalg._misc import _datacopied