diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bab2a911..3981a749 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -4,6 +4,7 @@ from serinv import ( ArrayLike, _get_module_from_array, + _get_module_from_str, ) @@ -47,8 +48,14 @@ def pobtas( else: # Natural arrowhead if device_streaming: - raise NotImplementedError( - "Streaming is not implemented for the natural arrowhead." + _pobtas_streaming( + L_diagonal_blocks, + L_lower_diagonal_blocks, + L_lower_arrow_blocks, + L_arrow_tip_block, + B, + trans, + partial, ) else: _pobtas( @@ -216,3 +223,402 @@ def _pobtas_permuted( ) else: raise ValueError(f"Invalid transpose argument: {trans}.") + + +def _pobtas_streaming( + L_diagonal_blocks: ArrayLike, + L_lower_diagonal_blocks: ArrayLike, + L_lower_arrow_blocks: ArrayLike, + L_arrow_tip_block: ArrayLike, + B: ArrayLike, + trans: str, + partial: bool, +): + arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks) + if arr_module.__name__ != "numpy": + raise TypeError( + "Host<->Device streaming only works when host-arrays are given." 
+ ) + + cp, cu_la = _get_module_from_str(module_str="cupy") + + # Vars + diag_blocksize = L_diagonal_blocks.shape[1] + arrow_blocksize = L_lower_arrow_blocks.shape[1] + n_diag_blocks = L_diagonal_blocks.shape[0] + + # Streams + compute_stream = cp.cuda.Stream(non_blocking=True) + h2d_stream = cp.cuda.Stream(non_blocking=True) + d2h_stream = cp.cuda.Stream(non_blocking=True) + + # Device Buffers + # B Buffers + B_shape = B[-arrow_blocksize:] # block template + B_arrow_tip_d = cp.empty_like(B_shape) + + B_shape = B[0:diag_blocksize] + B_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + + # L Buffers + L_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_arrow_blocks_d = cp.empty( + (2, *L_lower_arrow_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) + + if trans == "N": + # ----- Forward substitution ----- + # Delete helper variable + del B_shape + + # Events + h2d_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_lower_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_B_events = [cp.cuda.Event(), cp.cuda.Event()] + + d2h_B_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_tip_events = [cp.cuda.Event(), cp.cuda.Event()] + + compute_current_B_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_next_B_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_arrow_B_events = [cp.cuda.Event(), cp.cuda.Event()] + + compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] + + # --- C: events + transfers --- + compute_current_B_events[1].record(stream=compute_stream) + compute_next_B_events[1].record(stream=compute_stream) + compute_arrow_B_events[1].record(stream=compute_stream) + + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) + 
L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) + + # --- H2D: transfers --- + B_d[0].set(arr=B[0:diag_blocksize], stream=h2d_stream) + h2d_B_events[0].record(stream=h2d_stream) + + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) + h2d_diagonal_events[0].record(stream=h2d_stream) + + L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[0], stream=h2d_stream) + h2d_arrow_events[0].record(stream=h2d_stream) + + # --- D2H: event --- + d2h_B_events[1].record(stream=d2h_stream) + + n_diag_blocks: int = L_diagonal_blocks.shape[0] + + if n_diag_blocks > 1: + + L_lower_diagonal_blocks_d[0].set( + arr=L_lower_diagonal_blocks[0], stream=h2d_stream + ) + h2d_lower_diagonal_events[0].record(stream=h2d_stream) + + # --- Computations --- + for i in range(0, n_diag_blocks - 1): + # pass next B block + h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + + B_d[(i + 1) % 2].set( + arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=h2d_stream, + ) + + h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + + if i + 1 < n_diag_blocks - 1: + # pass next diagonal block + h2d_stream.wait_event(compute_current_B_events[(i + 1) % 2]) + L_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_diagonal_blocks[i + 1], stream=h2d_stream + ) + + h2d_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) + + with compute_stream: + # Solve current B block + compute_stream.wait_event(h2d_diagonal_events[i % 2]) + + B_d[i % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + ) + + compute_current_B_events[i % 2].record(stream=compute_stream) + + # Pass current B block back + + if i + 1 < n_diag_blocks - 1: + # Pass next lower diagonal block + h2d_stream.wait_event(compute_next_B_events[(i + 1) % 2]) + L_lower_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream + ) + + h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) + + 
d2h_stream.wait_event(compute_current_B_events[i % 2]) + d2h_stream.wait_event(h2d_lower_diagonal_events[(i + 1) % 2]) + + B_d[i % 2].get( + out=B[i * diag_blocksize : (i + 1) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + + d2h_B_events[i % 2].record(stream=d2h_stream) + + with compute_stream: + # Update next B block + compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) + + B_d[(i + 1) % 2] -= L_lower_diagonal_blocks_d[i % 2] @ B_d[i % 2] + + compute_next_B_events[i % 2].record(stream=compute_stream) + + if i + 1 < n_diag_blocks - 1: + # Pass next lower arrow block + h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + L_lower_arrow_blocks_d[(i + 1) % 2].set( + arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream + ) + + h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) + + with compute_stream: + # Update arrow tip + compute_stream.wait_event(h2d_arrow_events[i % 2]) + + B_arrow_tip_d -= L_lower_arrow_blocks_d[i % 2] @ B_d[i % 2] + + compute_arrow_B_events[i % 2].record(stream=compute_stream) + + # Pass arrow tip back + d2h_stream.wait_event(compute_arrow_B_events[n_diag_blocks % 2]) + + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) + + d2h_tip_events[n_diag_blocks % 2].record(stream=d2h_stream) + + if not partial: + # Pass last blocks + h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) + + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream + ) + + h2d_diagonal_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_lower_arrow_blocks[-1], stream=h2d_stream + ) + + h2d_arrow_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + with compute_stream: + # Solve last B block + compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) + + B_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + 
B_d[(n_diag_blocks - 1) % 2], + lower=True, + ) + + compute_partial_events[0].record(stream=compute_stream) + + # Pass last B block back + d2h_stream.wait_event(compute_partial_events[0]) + + B_d[(n_diag_blocks - 1) % 2].get( + out=B[ + (n_diag_blocks - 1) + * diag_blocksize : n_diag_blocks + * diag_blocksize + ], + stream=d2h_stream, + blocking=False, + ) + + d2h_B_events[0].record(stream=d2h_stream) + + with compute_stream: + # Solve arrow tip + compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) + + B_arrow_tip_d -= ( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2] + @ B_d[(n_diag_blocks - 1) % 2] + ) + B_arrow_tip_d = cu_la.solve_triangular( + L_arrow_tip_block_d, B_arrow_tip_d, lower=True + ) + + compute_partial_events[1].record(stream=compute_stream) + + d2h_stream.wait_event(compute_partial_events[1]) + + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) + + elif trans == "T" or trans == "C": + # ----- Backward substitution ----- + + # Buffers + B_previous_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + + # Delete helper variable + del B_shape + + # Events + compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_events = [cp.cuda.Event(), cp.cuda.Event()] + + # --- H2D: transfers --- + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) + L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) + B_d[(n_diag_blocks - 1) % 2].set( + arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[-1], stream=h2d_stream + ) + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_lower_arrow_blocks[-1], stream=h2d_stream + ) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + # ----- Backward substitution ----- + if not partial: + + with compute_stream: + # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} 
(Y_{ndb+1}) + compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) + B_arrow_tip_d = cu_la.solve_triangular( + L_arrow_tip_block_d, + B_arrow_tip_d, + lower=True, + trans="C", + ) + + # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) + B_previous_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2] + - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].conj().T + @ B_arrow_tip_d, + lower=True, + trans="C", + ) + + compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + + # Pass arrow tip back + d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) + + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) + + if n_diag_blocks > 1: + + B_d[n_diag_blocks % 2].set( + arr=B[ + -arrow_blocksize + - (2 * diag_blocksize) : -arrow_blocksize + - diag_blocksize + ], + stream=h2d_stream, + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_diagonal_blocks[-2], stream=h2d_stream + ) + L_lower_arrow_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_arrow_blocks[-2], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_diagonal_blocks[-1], stream=h2d_stream + ) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + for i in range(n_diag_blocks - 2, -1, -1): + + if i > 0: + # Pass new blocks + h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_d[(i - 1) % 2].set( + arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_arrow_blocks_d[(i - 1) % 2].set( + arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream + ) + + h2d_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} 
X_{i+1} - L_{ndb+1,i}^{T} X_{ndb+1}) + compute_stream.wait_event(h2d_events[(i - 1) % 2]) + compute_stream.wait_event(d2h_events[(i - 1) % 2]) + + B_previous_d[i % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[i % 2], + B_d[i % 2] + - L_lower_diagonal_blocks_d[i % 2].conj().T + @ B_previous_d[(i - 1) % 2] + - L_lower_arrow_blocks_d[i % 2].conj().T @ B_arrow_tip_d, + lower=True, + trans="C", + ) + + compute_B_events[i % 2].record(compute_stream) + + # Pass previous B block back + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_previous_d[(i - 1) % 2].get( + out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + d2h_events[i % 2].record(stream=d2h_stream) + + # Pass last B block back + d2h_stream.wait_event(compute_B_events[0]) + + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + + else: + raise ValueError(f"Invalid transpose argument: {trans}.") + + cp.cuda.Device().synchronize() diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 99aebc82..34c8e74e 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -4,6 +4,7 @@ from serinv import ( ArrayLike, _get_module_from_array, + _get_module_from_str, ) @@ -41,8 +42,11 @@ def pobts( else: # Natural arrowhead if device_streaming: - raise NotImplementedError( - "Streaming is not implemented for the natural arrowhead." + _pobts_streaming( + L_diagonal_blocks, + L_lower_diagonal_blocks, + B, + trans, ) else: _pobts( @@ -163,3 +167,222 @@ def _pobts_permuted( ) else: raise ValueError(f"Invalid transpose argument: {trans}.") + + +def _pobts_streaming( + L_diagonal_blocks: ArrayLike, + L_lower_diagonal_blocks: ArrayLike, + B: ArrayLike, + trans: str, +): + arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks) + if arr_module.__name__ != "numpy": + raise TypeError( + "Host<->Device streaming only works when host-arrays are given."
+ ) + + cp, cu_la = _get_module_from_str(module_str="cupy") + + # Vars + diag_blocksize = L_diagonal_blocks.shape[1] + n_diag_blocks = L_diagonal_blocks.shape[0] + + # Streams + compute_stream = cp.cuda.Stream(non_blocking=True) + h2d_stream = cp.cuda.Stream(non_blocking=True) + d2h_stream = cp.cuda.Stream(non_blocking=True) + + # Device Buffers + # B Buffers + B_shape = B[0:diag_blocksize] + B_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + B_previous_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + del B_shape + + # L Buffers + L_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + + # Events + compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_events = [cp.cuda.Event(), cp.cuda.Event()] + + if trans == "N": + # ----- Forward substitution ----- + + # --- H2D: transfers --- + B_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) + + h2d_events[1].record(stream=h2d_stream) + + if n_diag_blocks > 1: + B_d[1].set(arr=B[diag_blocksize : (2 * diag_blocksize)], stream=h2d_stream) + L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) + L_lower_diagonal_blocks_d[1].set( + arr=L_lower_diagonal_blocks[0], stream=h2d_stream + ) + + h2d_events[0].record(stream=h2d_stream) + + with compute_stream: + # Solve first B block + compute_stream.wait_event(h2d_events[1]) + + B_previous_d[0] = cu_la.solve_triangular( + L_diagonal_blocks_d[0], + B_d[0], + lower=True, + ) + + compute_B_events[0].record(stream=compute_stream) + + for i in range(1, n_diag_blocks): + + if i + 1 < n_diag_blocks: + # Pass next blocks + h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) + + B_d[(i + 1) % 2].set( + arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + 
stream=h2d_stream, + ) + L_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_diagonal_blocks[i + 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_lower_diagonal_blocks[i], stream=h2d_stream + ) + + h2d_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + # Y_{i} = L_{i,i}^{-1} (B_{i} - L_{i,i-1} Y_{i-1}) + compute_stream.wait_event(h2d_events[(i + 1) % 2]) + compute_stream.wait_event(d2h_events[(i + 1) % 2]) + + B_previous_d[i % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[i % 2], + B_d[i % 2] + - L_lower_diagonal_blocks_d[i % 2] @ B_previous_d[(i + 1) % 2], + lower=True, + ) + + compute_B_events[i % 2].record(compute_stream) + + # Pass previous B block back + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_previous_d[(i + 1) % 2].get( + out=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + + d2h_events[i % 2].record(stream=d2h_stream) + + # Pass last B block back + d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2]) + + B_previous_d[(n_diag_blocks + 1) % 2].get( + out=B[-diag_blocksize:], stream=d2h_stream, blocking=False + ) + + elif trans == "T" or trans == "C": + # ----- Backward substitution ----- + + # --- H2D: transfers --- + B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[-1], stream=h2d_stream + ) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + if n_diag_blocks > 1: + + B_d[n_diag_blocks % 2].set( + arr=B[-(2 * diag_blocksize) : -diag_blocksize], stream=h2d_stream + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_diagonal_blocks[-2], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_diagonal_blocks[-1], stream=h2d_stream + ) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + with compute_stream: + # X_{ndb} = L_{ndb,ndb}^{-T}
Y_{ndb} + compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) + + B_previous_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + trans="C", + ) + + compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + + for i in range(n_diag_blocks - 2, -1, -1): + + if i > 0: + # pass next blocks + h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_d[(i - 1) % 2].set( + arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream + ) + + h2d_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) + compute_stream.wait_event(h2d_events[(i - 1) % 2]) + compute_stream.wait_event(d2h_events[(i - 1) % 2]) + + B_previous_d[i % 2] = cu_la.solve_triangular( + L_diagonal_blocks_d[i % 2], + B_d[i % 2] + - L_lower_diagonal_blocks_d[i % 2].conj().T + @ B_previous_d[(i - 1) % 2], + lower=True, + trans="C", + ) + + compute_B_events[i % 2].record(compute_stream) + + # Pass previous B block back + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_previous_d[(i - 1) % 2].get( + out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + + d2h_events[i % 2].record(stream=d2h_stream) + + # Pass last B block back + d2h_stream.wait_event(compute_B_events[0]) + + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + + else: + raise ValueError(f"Invalid transpose argument: {trans}.") + + cp.cuda.Device().synchronize() diff --git a/src/serinv/wrappers/ddbtars.py b/src/serinv/wrappers/ddbtars.py index 722582a6..17a12993 100644 ---
a/src/serinv/wrappers/ddbtars.py +++ b/src/serinv/wrappers/ddbtars.py @@ -15,7 +15,6 @@ import cupyx as cpx import cupy as cp - def allocate_ddbtars( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, diff --git a/src/serinv/wrappers/pddbtasci.py b/src/serinv/wrappers/pddbtasci.py index f235d76b..3ad34f79 100644 --- a/src/serinv/wrappers/pddbtasci.py +++ b/src/serinv/wrappers/pddbtasci.py @@ -44,7 +44,7 @@ def pddbtasci( The arrow tip block of the block tridiagonal with arrowhead matrix. comm : MPI.Comm The MPI communicator. Default is MPI.COMM_WORLD. - + Keyword Arguments ----------------- rhs : dict diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index e9f1eb9e..11d62461 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -194,4 +194,5 @@ def pddbtsc( comm.Barrier() + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/pddbtsci.py b/src/serinv/wrappers/pddbtsci.py index 63d94c84..18d5ea00 100644 --- a/src/serinv/wrappers/pddbtsci.py +++ b/src/serinv/wrappers/pddbtsci.py @@ -34,7 +34,7 @@ def pddbtsci( The upper diagonal blocks of the block tridiagonal matrix. comm : MPI.Comm The MPI communicator. Default is MPI.COMM_WORLD. - + Keyword Arguments ----------------- rhs : dict diff --git a/tests/conftest.py b/tests/conftest.py index 3e624933..25b716a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. # Global pytest fixtures for the Serinv tests. 
- import pytest from serinv import backend_flags @@ -15,7 +14,6 @@ ] ) - DTYPE = [ pytest.param("float64", id="float64"), pytest.param("complex128", id="complex128"), diff --git a/tests/tests_algs/regular/tests_bt/test_pobtf.py b/tests/tests_algs/regular/tests_bt/test_pobtf.py index d1969b05..0ac3ae89 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobtf.py +++ b/tests/tests_algs/regular/tests_bt/test_pobtf.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize @@ -11,6 +13,16 @@ if backend_flags["cupy_avail"]: import cupyx as cpx + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtf( diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index 8125df52..9011caa1 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -3,11 +3,26 @@ import numpy as np import pytest -from serinv import _get_module_from_array +from ....conftest import ARRAY_TYPE + +from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize, rhs from serinv.algs import pobtf, pobts +if backend_flags["cupy_avail"]: + import cupyx as cpx + + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) @@ -18,6 +33,7 @@ def test_pobts( array_type: str, dtype: np.dtype, ): + A = dd_bt( diagonal_blocksize, n_diag_blocks, @@ -47,9 +63,22 @@ def test_pobts( 
_, ) = bt_dense_to_arrays(A, diagonal_blocksize, n_diag_blocks) + if backend_flags["cupy_avail"] and array_type == "streaming": + A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks) + A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :] + A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks) + A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks[:, :, :] + B_pinned = cpx.zeros_like_pinned(B) + B_pinned[:, :] = B[:, :] + + A_diagonal_blocks = A_diagonal_blocks_pinned + A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned + B = B_pinned + pobtf( A_diagonal_blocks, A_lower_diagonal_blocks, + device_streaming=True if array_type == "streaming" else False, ) # Forward solve: Y=L^{-1}B @@ -58,6 +87,7 @@ def test_pobts( A_lower_diagonal_blocks, B, trans="N", + device_streaming=True if array_type == "streaming" else False, ) # Backward solve: X=L^{-T}Y @@ -66,6 +96,7 @@ def test_pobts( A_lower_diagonal_blocks, B, trans="C", + device_streaming=True if array_type == "streaming" else False, ) assert xp.allclose(B, X_ref) diff --git a/tests/tests_algs/regular/tests_bt/test_pobtsi.py b/tests/tests_algs/regular/tests_bt/test_pobtsi.py index 22ec1d3a..5463a7b9 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobtsi.py +++ b/tests/tests_algs/regular/tests_bt/test_pobtsi.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize @@ -11,6 +13,16 @@ if backend_flags["cupy_avail"]: import cupyx as cpx + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtsi( diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py 
b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index a30b9094..920c6292 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -3,15 +3,29 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize from serinv.algs import pobtaf +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + if backend_flags["cupy_avail"]: import cupyx as cpx +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + + @pytest.mark.mpi_skip() def test_pobtaf( diagonal_blocksize: int, diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 647a0168..ffc290c2 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -3,11 +3,28 @@ import numpy as np import pytest -from serinv import _get_module_from_array +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + +from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize, rhs from serinv.algs import pobtaf, pobtas +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + +if backend_flags["cupy_avail"]: + import cupyx as cpx + + +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) @@ -19,6 +36,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): + A = dd_bta( diagonal_blocksize, arrowhead_blocksize, @@ -51,11 +69,30 @@ def test_pobtas( A_arrow_tip_block, ) = bta_dense_to_arrays(A, 
diagonal_blocksize, arrowhead_blocksize, n_diag_blocks) + if backend_flags["cupy_avail"] and array_type == "streaming": + A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks) + A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :] + A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks) + A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks[:, :, :] + A_lower_arrow_blocks_pinned = cpx.zeros_like_pinned(A_lower_arrow_blocks) + A_lower_arrow_blocks_pinned[:, :, :] = A_lower_arrow_blocks[:, :, :] + A_arrow_tip_block_pinned = cpx.zeros_like_pinned(A_arrow_tip_block) + A_arrow_tip_block_pinned[:, :] = A_arrow_tip_block[:, :] + B_pinned = cpx.zeros_like_pinned(B) + B_pinned[:, :] = B[:, :] + + A_diagonal_blocks = A_diagonal_blocks_pinned + A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned + A_lower_arrow_blocks = A_lower_arrow_blocks_pinned + A_arrow_tip_block = A_arrow_tip_block_pinned + B = B_pinned + pobtaf( A_diagonal_blocks, A_lower_diagonal_blocks, A_lower_arrow_blocks, A_arrow_tip_block, + device_streaming=True if array_type == "streaming" else False, ) # Forward solve: Y=L^{-1}B @@ -66,6 +103,7 @@ def test_pobtas( A_arrow_tip_block, B, trans="N", + device_streaming=True if array_type == "streaming" else False, ) # Backward solve: X=L^{-T}Y @@ -76,6 +114,9 @@ def test_pobtas( A_arrow_tip_block, B, trans="C", + device_streaming=True if array_type == "streaming" else False, ) + print("===") + print(X_ref) assert xp.allclose(B, X_ref)