Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
242 commits
Select commit Hold shift + click to select a range
cb018c5
first version of pobtas streaming
03szust Apr 24, 2025
56f1333
change tests incase that broke it
03szust Apr 24, 2025
7babade
test update
03szust Apr 24, 2025
ce19c74
typo
03szust Apr 24, 2025
daae3ab
debug statements
03szust Apr 24, 2025
745b654
debug changes
03szust Apr 24, 2025
5929b08
debug messages
03szust Apr 24, 2025
589d986
print B
03szust Apr 24, 2025
1b0a754
changed B_d shape
03szust Apr 24, 2025
4bea047
changed wrong arrays in streaming
03szust Apr 24, 2025
6d2eacf
debug shapes
03szust Apr 24, 2025
b2cb421
typo
03szust Apr 24, 2025
8c11dcb
compare B and L
03szust Apr 24, 2025
b6fdc79
changed B slice in 1
03szust Apr 24, 2025
ff6bc95
changed actual B slices
03szust Apr 24, 2025
8b475f1
changed further B slice
03szust Apr 24, 2025
df9e172
fixed typos
03szust Apr 24, 2025
fd604ce
changed last B slice
03szust Apr 24, 2025
7e84329
changed index for lower diag blocks
03szust Apr 24, 2025
a42e226
changed index for diagonal blocks
03szust Apr 24, 2025
3a9d5dc
inserted ifs for termination
03szust Apr 24, 2025
15fcc70
fixed typo
03szust Apr 24, 2025
473ff93
fixed typo
03szust Apr 24, 2025
15f1066
insert debug prints
03szust Apr 24, 2025
bbe5dc6
typo
03szust Apr 24, 2025
1ef1b83
changed b last block
03szust Apr 24, 2025
cccae25
fixed lower arrow blocks in partial
03szust Apr 24, 2025
e705217
changed typo
03szust Apr 24, 2025
d413106
changed test
03szust Apr 24, 2025
6a181af
new debug print
03szust Apr 24, 2025
41d7c60
changed logic to accomodate arrow sizes
03szust Apr 25, 2025
43cd682
typo
03szust Apr 25, 2025
c4aba12
fixed function
03szust Apr 25, 2025
e252cf1
insert debug statements
03szust Apr 25, 2025
f46b97a
more debugging
03szust Apr 25, 2025
5e263b6
debugging second to last solve
03szust Apr 25, 2025
2c2b5f1
fixed second to last solve
03szust Apr 25, 2025
d82d1df
removed debugging statements
03szust Apr 25, 2025
4dd8aa5
insert debug statements
03szust Apr 25, 2025
8561032
fixed index typo
03szust Apr 25, 2025
b263c36
changed debug statement
03szust Apr 25, 2025
2531027
changed operation order
03szust Apr 25, 2025
88d76bd
changed to right B
03szust Apr 25, 2025
dfb761b
setup corrected for out of bounds
03szust Apr 25, 2025
592c435
removed debug statements
03szust Apr 25, 2025
ce7a553
insert debug statement
03szust Apr 25, 2025
c5e249c
forced streaming in tests again
03szust Apr 25, 2025
ebe8fa7
force streaming in pobtaf for testing
03szust Apr 25, 2025
979bdcf
removed forced streaming from pobtaf
03szust Apr 25, 2025
c57c8db
changed stream timing
03szust Apr 29, 2025
8fc22ab
added sync
03szust Apr 29, 2025
6c15026
insert debug statements
03szust Apr 29, 2025
a26688e
insert antoher debug statement
03szust Apr 29, 2025
befd8c3
removed misguided overlap protection
03szust Apr 29, 2025
3975f51
changed streaming order
03szust Apr 29, 2025
d1773ae
rolled back if statement
03szust Apr 29, 2025
59c7bd9
debug statement to check if the last block is the problem
03szust Apr 29, 2025
62b3e9f
changed non partial solve
03szust Apr 29, 2025
9b97f8a
debug to see passed tests
03szust Apr 29, 2025
250a4c9
inserted debug statements to compare B
03szust Apr 29, 2025
bfa1e47
changed arrow tip block
03szust Apr 29, 2025
7894956
changed stream timing
03szust Apr 29, 2025
6fc35a5
changed if to stream b + 1
03szust Apr 29, 2025
029c7ce
debug changed to check n
03szust Apr 29, 2025
055bc0e
consitentcy update
03szust Apr 29, 2025
c048273
changed non partial part
03szust Apr 29, 2025
35d4b82
changed non partial block to match indexing
03szust Apr 29, 2025
9b21cc5
first attempt at backward solve
03szust Apr 29, 2025
c00d7a4
fixed typo
03szust Apr 29, 2025
92700a1
another typo
03szust Apr 29, 2025
5bee42c
insert parenthesis
03szust Apr 29, 2025
2917335
insert debug staetments
03szust Apr 29, 2025
232c083
more debug
03szust Apr 29, 2025
191e361
added missing streaming
03szust Apr 29, 2025
996d0cd
added debug statements
03szust Apr 29, 2025
d354f75
changed debug
03szust Apr 29, 2025
3e65913
new debug statements
03szust Apr 29, 2025
ad8a149
new debugs
03szust Apr 29, 2025
beee970
changed stream timing
03szust Apr 29, 2025
ef775f6
adjusted stram timing
03szust Apr 29, 2025
debc23b
changed event recording
03szust Apr 29, 2025
37abb09
more debug
03szust Apr 29, 2025
dd350b9
insert first compare debug
03szust Apr 29, 2025
b223617
second debug compare
03szust Apr 29, 2025
6e69afe
inserted lower diagonal blocks streaming
03szust Apr 29, 2025
a8aa23d
debug compare 3
03szust Apr 29, 2025
bd17add
compare 4
03szust Apr 29, 2025
d13eb0f
changed location of B_previous
03szust Apr 29, 2025
3fdc3e4
added previous B setup
03szust Apr 29, 2025
ae99000
fixed indexing
03szust Apr 29, 2025
24bf36c
moved brevious b from if
03szust Apr 29, 2025
3caa601
moved previous b from correct if
03szust Apr 29, 2025
8cba18e
removed debug statements
03szust Apr 29, 2025
03072c6
moved a wait event
03szust Apr 29, 2025
f6a390b
delayed d2h stream
03szust Apr 29, 2025
26cb9af
adjusted stream timing
03szust Apr 29, 2025
898955c
even more adjusted timing
03szust Apr 29, 2025
8c14ed3
changed streaming order
03szust Apr 30, 2025
4899bb3
removed strange get
03szust Apr 30, 2025
e96994c
insert debug staetments
03szust Apr 30, 2025
184be2d
changed debug
03szust Apr 30, 2025
62b12ab
changed last get
03szust Apr 30, 2025
216664f
more debugging
03szust Apr 30, 2025
a942a65
changed B events
03szust Apr 30, 2025
af23592
print B_d
03szust Apr 30, 2025
e512e8f
insert seperator print
03szust Apr 30, 2025
2b9958c
changed location of previous B event
03szust Apr 30, 2025
22ecb37
changed order of compute stream
03szust Apr 30, 2025
336e831
switched chose previous B
03szust Apr 30, 2025
84270b1
changed wait event
03szust Apr 30, 2025
974e9c2
changed another wait event
03szust Apr 30, 2025
8a17642
changed stream pattern
03szust Apr 30, 2025
8b6ba38
changed previous B
03szust Apr 30, 2025
684a785
removed last B get
03szust Apr 30, 2025
5f2f947
changed indexing
03szust Apr 30, 2025
b09e6d1
changed streaming a bit
03szust Apr 30, 2025
e94c998
insert debug
03szust Apr 30, 2025
bf74bcd
more debug
03szust Apr 30, 2025
280ce36
inser print B
03szust Apr 30, 2025
4445c3c
another print B
03szust Apr 30, 2025
473911b
print xref
03szust Apr 30, 2025
4f43057
more debug
03szust Apr 30, 2025
dbe0abf
another B_d print
03szust Apr 30, 2025
025f062
insert last B d2h
03szust Apr 30, 2025
e052e92
condition last stream
03szust Apr 30, 2025
ee5557f
insert wait event for last stream
03szust Apr 30, 2025
50f009e
backward solve working
03szust Apr 30, 2025
b39139f
bigger tests
03szust Apr 30, 2025
a513fa6
even bigger tests
03szust Apr 30, 2025
1461fb8
reverted tests for now
03szust Apr 30, 2025
1fd1c49
first attempt at adapted code for pobts
03szust May 1, 2025
9b56990
removed not implemented error
03szust May 1, 2025
7a40396
insert debug
03szust May 1, 2025
4bc0e71
fixed array slicing
03szust May 1, 2025
5ff5e64
pobts streaming working
03szust May 1, 2025
4fbf506
first attempt at pobts forward streaming by flipping it
03szust May 1, 2025
08d2f76
added test logic
03szust May 1, 2025
8c2578f
changed indexing
03szust May 1, 2025
4253208
fixed more indexing
03szust May 1, 2025
f103797
switched event order
03szust May 1, 2025
b7cc662
changed first block logic
03szust May 1, 2025
4ecfa60
fixed solve
03szust May 1, 2025
0ab89a1
insert debug statement
03szust May 1, 2025
65b7b18
changed lower diagonal order
03szust May 1, 2025
ab01d34
inser debug message
03szust May 1, 2025
bc3c312
changed slicing
03szust May 1, 2025
52e9004
adjusted loop
03szust May 1, 2025
7ec0eba
adjusted loop
03szust May 1, 2025
62590f4
changed previous B
03szust May 1, 2025
b50b431
insert debug check 1
03szust May 1, 2025
f1e38cb
adjusted streaming
03szust May 1, 2025
6f8a5b5
adjusted streaming
03szust May 1, 2025
72cb3b8
insert more debug
03szust May 1, 2025
0a94ae1
expanded for loop
03szust May 1, 2025
734adfa
adjusted streaming
03szust May 1, 2025
029469b
check number 2
03szust May 1, 2025
8dd0adb
shifted indexing
03szust May 1, 2025
7e19f02
changed lower streaming
03szust May 1, 2025
bd10a8b
more debug
03szust May 1, 2025
50f7b05
removed some debug
03szust May 1, 2025
9b2f8a5
debug number 3
03szust May 1, 2025
ce86a5b
changed B streaming
03szust May 1, 2025
3645026
more changes to B streaming
03szust May 1, 2025
6d53eab
changed B previous
03szust May 1, 2025
66a78d2
removed wrong transposition
03szust May 1, 2025
16050ed
debug check 4
03szust May 1, 2025
a4d50cd
debug b previous
03szust May 1, 2025
82f1840
moved debug message
03szust May 1, 2025
05ae212
shift B previous get
03szust May 1, 2025
8fcb99f
changed last B
03szust May 1, 2025
9402074
test for last B
03szust May 1, 2025
5dd7f1e
revert
03szust May 1, 2025
a67e760
try different stream order
03szust May 1, 2025
af8adbc
insert failsafe
03szust May 1, 2025
1a1e329
more failsafe
03szust May 1, 2025
ae78728
removed unnecessary events
03szust May 1, 2025
f866ad0
stream failsafes
03szust May 1, 2025
0876e0e
more failsafe
03szust May 1, 2025
da1fe74
changed faulty event
03szust May 1, 2025
fd4e68b
changed last stream
03szust May 1, 2025
eb67c78
removed unnecessary events
03szust May 1, 2025
db63cd3
more parity
03szust May 1, 2025
e135eb2
more failsafes
03szust May 1, 2025
01dca7c
cosmetic changes
03szust May 1, 2025
e7d4646
more cosmetic changes
03szust May 1, 2025
1ab8063
attempt to reduce streaming
03szust May 1, 2025
e2971af
reduced streaming
03szust May 1, 2025
50b0cd4
attempt to reduce streaming
03szust May 1, 2025
8e07ceb
parity reduced streaming
03szust May 1, 2025
c6bb7c2
attempt to fuirther reduce streaming
03szust May 1, 2025
4e08701
speed up setup attempt
03szust May 1, 2025
9b3b0eb
expand delay reduction
03szust May 1, 2025
74f17ac
comment changes
03szust May 1, 2025
2fc0dbd
check for useless if
03szust May 1, 2025
bcfc6c8
check for duplicate
03szust May 1, 2025
4af7744
reverted
03szust May 1, 2025
772da8d
reduced for loop
03szust May 1, 2025
ff24ea0
reordered streaming
03szust May 1, 2025
7a40c76
moved streaming and added documentation
03szust May 1, 2025
81c8384
bigger tests
03szust May 2, 2025
7df30e0
even bigger tests
03szust May 2, 2025
d95644e
even more bigger tests
03szust May 2, 2025
694b363
changed tests to be smaller
03szust May 2, 2025
9edc1fa
smaller tests again
03szust May 2, 2025
fae1ecb
reset tests
03szust May 2, 2025
f3852ff
add scripts for cscs
May 8, 2025
2db62af
updarte bash script
03szust May 8, 2025
8f4b19f
removed load_modules
03szust May 9, 2025
c110d9f
changed file path
03szust May 9, 2025
2cd787c
change to enable streaming on daint
03szust May 9, 2025
d559056
added check message
03szust May 15, 2025
2c262b9
changed given arrays
03szust May 15, 2025
10b352c
rolled back block choice for further testing
03szust May 15, 2025
610cf24
attempt to activate streaming
03szust May 15, 2025
134a9ef
typo
03szust May 15, 2025
439ede8
another typo
03szust May 15, 2025
24d9b3a
enable streaming for pobtaf
03szust May 15, 2025
8648d1d
removing copy
03szust May 15, 2025
1e189c7
pinned memory
03szust May 15, 2025
bd2c613
typo
03szust May 15, 2025
af86157
changed block name
03szust May 15, 2025
26adf61
import cupyx
03szust May 15, 2025
4f03cab
missing B_cpu
03szust May 15, 2025
4effd14
changed nvtx
03szust May 15, 2025
044d2d6
moved pop
03szust May 15, 2025
7cd7d08
untangled streaming
03szust May 15, 2025
9d21413
modified tests
03szust May 16, 2025
c2cb681
pytest array_type override
03szust May 16, 2025
89913e8
changed tests a bit to not override
03szust May 16, 2025
2b7a5ca
activate pobtaf streaming in tests
03szust May 16, 2025
afcc2d0
removed nvtx and tests the tests
03szust May 16, 2025
9ed63e9
removed test testing
03szust May 16, 2025
17f87e7
expanded tests
03szust May 16, 2025
6a78a85
expanded tests further
03szust May 16, 2025
2b94651
activated streaming tests for pobtaf
03szust May 16, 2025
6a5f68f
removed leftover cscs scripts
03szust May 16, 2025
cc7f7d0
removed line that forced streaming
03szust May 16, 2025
907dd11
first modification to get cupy and scipy implementations for trsm rig…
03szust May 27, 2025
79b85e5
moved improvement files to new branch
03szust May 27, 2025
e899744
changed errors
03szust Jun 6, 2025
53d5e5a
unified (and added) test streaming for pobtaf/si
vincent-maillou Jun 5, 2025
a361b9c
just ran `black .`
vincent-maillou Jun 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
410 changes: 408 additions & 2 deletions src/serinv/algs/pobtas.py

Large diffs are not rendered by default.

227 changes: 225 additions & 2 deletions src/serinv/algs/pobts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from serinv import (
ArrayLike,
_get_module_from_array,
_get_module_from_str,
)


Expand Down Expand Up @@ -41,8 +42,11 @@ def pobts(
else:
# Natural arrowhead
if device_streaming:
raise NotImplementedError(
"Streaming is not implemented for the natural arrowhead."
_pobts_streaming(
L_diagonal_blocks,
L_lower_diagonal_blocks,
B,
trans,
)
else:
_pobts(
Expand Down Expand Up @@ -163,3 +167,222 @@ def _pobts_permuted(
)
else:
raise ValueError(f"Invalid transpose argument: {trans}.")


def _pobts_streaming(
L_diagonal_blocks: ArrayLike,
L_lower_diagonal_blocks: ArrayLike,
B: ArrayLike,
trans: str,
):
arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks)
if arr_module.__name__ != "numpy":
raise TypeError(
"Host<->Device streaming only works when host-arrays are given."
)

cp, cu_la = _get_module_from_str(module_str="cupy")

# Vars
diag_blocksize = L_diagonal_blocks.shape[1]
n_diag_blocks = L_diagonal_blocks.shape[0]

# Streams
compute_stream = cp.cuda.Stream(non_blocking=True)
h2d_stream = cp.cuda.Stream(non_blocking=True)
d2h_stream = cp.cuda.Stream(non_blocking=True)

# Device Buffers
# B Buffers
B_shape = B[0:diag_blocksize]
B_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype)
B_previous_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype)
del B_shape

# L Buffers
L_diagonal_blocks_d = cp.empty(
(2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype
)
L_lower_diagonal_blocks_d = cp.empty(
(2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype
)

# Events
compute_B_events = [cp.cuda.Event(), cp.cuda.Event()]
h2d_events = [cp.cuda.Event(), cp.cuda.Event()]
d2h_events = [cp.cuda.Event(), cp.cuda.Event()]

if trans == "N":
# ----- Forward substitution -----

# --- H2D: transfers ---
B_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream)
L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream)

h2d_events[1].record(stream=h2d_stream)

if n_diag_blocks > 1:
B_d[1].set(arr=B[diag_blocksize : (2 * diag_blocksize)], stream=h2d_stream)
L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream)
L_lower_diagonal_blocks_d[1].set(
arr=L_lower_diagonal_blocks[0], stream=h2d_stream
)

h2d_events[0].record(stream=h2d_stream)

with compute_stream:
# Solve first B block
compute_stream.wait_event(h2d_events[1])

B_previous_d[0] = cu_la.solve_triangular(
L_diagonal_blocks_d[0],
B_d[0],
lower=True,
)

compute_B_events[0].record(stream=compute_stream)

for i in range(1, n_diag_blocks):

if i + 1 < n_diag_blocks:
# Pass next blocks
h2d_stream.wait_event(compute_B_events[(i + 1) % 2])

B_d[(i + 1) % 2].set(
arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize],
stream=h2d_stream,
)
L_diagonal_blocks_d[(i + 1) % 2].set(
arr=L_diagonal_blocks[i + 1], stream=h2d_stream
)
L_lower_diagonal_blocks_d[(i + 1) % 2].set(
arr=L_lower_diagonal_blocks[i], stream=h2d_stream
)

h2d_events[i % 2].record(stream=h2d_stream)

with compute_stream:
# X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1}
compute_stream.wait_event(h2d_events[(i + 1) % 2])
compute_stream.wait_event(d2h_events[(i + 1) % 2])

B_previous_d[i % 2] = cu_la.solve_triangular(
L_diagonal_blocks_d[i % 2],
B_d[i % 2]
- L_lower_diagonal_blocks_d[i % 2] @ B_previous_d[(i + 1) % 2],
lower=True,
)

compute_B_events[i % 2].record(compute_stream)

# Pass previous B block back
d2h_stream.wait_event(compute_B_events[(i - 1) % 2])

B_previous_d[(i + 1) % 2].get(
out=B[(i - 1) * diag_blocksize : i * diag_blocksize],
stream=d2h_stream,
blocking=False,
)

d2h_events[i % 2].record(stream=d2h_stream)

# Pass last B block back
d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2])

B_previous_d[(n_diag_blocks + 1) % 2].get(
out=B[-diag_blocksize:], stream=d2h_stream, blocking=False
)

elif trans == "T" or trans == "C":
# ----- Backward substitution -----

# --- H2D: transfers ---
B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream)
L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(
arr=L_diagonal_blocks[-1], stream=h2d_stream
)

h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream)

if n_diag_blocks > 1:

B_d[n_diag_blocks % 2].set(
arr=B[-(2 * diag_blocksize) : -diag_blocksize], stream=h2d_stream
)
L_diagonal_blocks_d[n_diag_blocks % 2].set(
arr=L_diagonal_blocks[-2], stream=h2d_stream
)
L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(
arr=L_lower_diagonal_blocks[-1], stream=h2d_stream
)

h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream)

with compute_stream:
# X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1})
compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2])

B_previous_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular(
L_diagonal_blocks_d[(n_diag_blocks - 1) % 2],
B_d[(n_diag_blocks - 1) % 2],
lower=True,
trans="C",
)

compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream)

for i in range(n_diag_blocks - 2, -1, -1):

if i > 0:
# pass next blocks
h2d_stream.wait_event(compute_B_events[(i - 1) % 2])

B_d[(i - 1) % 2].set(
arr=B[(i - 1) * diag_blocksize : i * diag_blocksize],
stream=h2d_stream,
)
L_diagonal_blocks_d[(i - 1) % 2].set(
arr=L_diagonal_blocks[i - 1], stream=h2d_stream
)
L_lower_diagonal_blocks_d[(i - 1) % 2].set(
arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream
)

h2d_events[i % 2].record(stream=h2d_stream)

with compute_stream:
# X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1}
compute_stream.wait_event(h2d_events[(i - 1) % 2])
compute_stream.wait_event(d2h_events[(i - 1) % 2])

B_previous_d[i % 2] = cu_la.solve_triangular(
L_diagonal_blocks_d[i % 2],
B_d[i % 2]
- L_lower_diagonal_blocks_d[i % 2].conj().T
@ B_previous_d[(i - 1) % 2],
lower=True,
trans="C",
)

compute_B_events[i % 2].record(compute_stream)

# Pass previous B block back
d2h_stream.wait_event(compute_B_events[(i - 1) % 2])

B_previous_d[(i - 1) % 2].get(
out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize],
stream=d2h_stream,
blocking=False,
)

d2h_events[i % 2].record(stream=d2h_stream)

# Pass last B block back
d2h_stream.wait_event(compute_B_events[0])

B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False)

else:
raise ValueError(f"Invalid transpose argument: {trans}.")

cp.cuda.Device().synchronize()
1 change: 0 additions & 1 deletion src/serinv/wrappers/ddbtars.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import cupyx as cpx
import cupy as cp


def allocate_ddbtars(
A_diagonal_blocks: ArrayLike,
A_lower_diagonal_blocks: ArrayLike,
Expand Down
2 changes: 1 addition & 1 deletion src/serinv/wrappers/pddbtasci.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def pddbtasci(
The arrow tip block of the block tridiagonal with arrowhead matrix.
comm : MPI.Comm
The MPI communicator. Default is MPI.COMM_WORLD.

Keyword Arguments
-----------------
rhs : dict
Expand Down
1 change: 1 addition & 0 deletions src/serinv/wrappers/pddbtsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,5 @@ def pddbtsc(

comm.Barrier()


return elapsed
2 changes: 1 addition & 1 deletion src/serinv/wrappers/pddbtsci.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def pddbtsci(
The upper diagonal blocks of the block tridiagonal matrix.
comm : MPI.Comm
The MPI communicator. Default is MPI.COMM_WORLD.

Keyword Arguments
-----------------
rhs : dict
Expand Down
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright 2023-2025 ETH Zurich. All rights reserved.
# Global pytest fixtures for the Serinv tests.

import pytest

from serinv import backend_flags
Expand All @@ -15,7 +14,6 @@
]
)


DTYPE = [
pytest.param("float64", id="float64"),
pytest.param("complex128", id="complex128"),
Expand Down
12 changes: 12 additions & 0 deletions tests/tests_algs/regular/tests_bt/test_pobtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import pytest

from ....conftest import ARRAY_TYPE

from serinv import backend_flags, _get_module_from_array
from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize

Expand All @@ -11,6 +13,16 @@
if backend_flags["cupy_avail"]:
import cupyx as cpx

ARRAY_TYPE.extend(
[
pytest.param("streaming", id="streaming"),
]
)

@pytest.fixture(params=ARRAY_TYPE, autouse=True)
def array_type(request: pytest.FixtureRequest) -> str:
return request.param


@pytest.mark.mpi_skip()
def test_pobtf(
Expand Down
33 changes: 32 additions & 1 deletion tests/tests_algs/regular/tests_bt/test_pobts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,26 @@
import numpy as np
import pytest

from serinv import _get_module_from_array
from ....conftest import ARRAY_TYPE

from serinv import backend_flags, _get_module_from_array
from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize, rhs

from serinv.algs import pobtf, pobts

if backend_flags["cupy_avail"]:
import cupyx as cpx

ARRAY_TYPE.extend(
[
pytest.param("streaming", id="streaming"),
]
)

@pytest.fixture(params=ARRAY_TYPE, autouse=True)
def array_type(request: pytest.FixtureRequest) -> str:
return request.param


@pytest.mark.mpi_skip()
@pytest.mark.parametrize("n_rhs", [1, 2, 3])
Expand All @@ -18,6 +33,7 @@ def test_pobts(
array_type: str,
dtype: np.dtype,
):

A = dd_bt(
diagonal_blocksize,
n_diag_blocks,
Expand Down Expand Up @@ -47,9 +63,22 @@ def test_pobts(
_,
) = bt_dense_to_arrays(A, diagonal_blocksize, n_diag_blocks)

if backend_flags["cupy_avail"] and array_type == "streaming":
A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks)
A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :]
A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks)
A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks[:, :, :]
B_pinned = cpx.zeros_like_pinned(B)
B_pinned[:, :] = B[:, :]

A_diagonal_blocks = A_diagonal_blocks_pinned
A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned
B = B_pinned

pobtf(
A_diagonal_blocks,
A_lower_diagonal_blocks,
device_streaming=True if array_type == "streaming" else False,
)

# Forward solve: Y=L^{-1}B
Expand All @@ -58,6 +87,7 @@ def test_pobts(
A_lower_diagonal_blocks,
B,
trans="N",
device_streaming=True if array_type == "streaming" else False,
)

# Backward solve: X=L^{-T}Y
Expand All @@ -66,6 +96,7 @@ def test_pobts(
A_lower_diagonal_blocks,
B,
trans="C",
device_streaming=True if array_type == "streaming" else False,
)

assert xp.allclose(B, X_ref)
Loading