Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
244 commits
Select commit Hold shift + click to select a range
30cf672
first version of pobtas streaming
03szust Apr 24, 2025
9e97c5e
change tests incase that broke it
03szust Apr 24, 2025
380a084
test update
03szust Apr 24, 2025
320ebef
typo
03szust Apr 24, 2025
6eeb0f9
debug statements
03szust Apr 24, 2025
137112f
debug changes
03szust Apr 24, 2025
dfab4ab
debug messages
03szust Apr 24, 2025
e068b84
print B
03szust Apr 24, 2025
2981b3e
changed B_d shape
03szust Apr 24, 2025
29c674f
changed wrong arrays in streaming
03szust Apr 24, 2025
3bffe7d
debug shapes
03szust Apr 24, 2025
1e8acad
typo
03szust Apr 24, 2025
ec0a01b
compare B and L
03szust Apr 24, 2025
6a657cc
changed B slice in 1
03szust Apr 24, 2025
c3bb244
changed actual B slices
03szust Apr 24, 2025
11a9cc6
changed further B slice
03szust Apr 24, 2025
935848d
fixed typos
03szust Apr 24, 2025
f527a69
changed last B slice
03szust Apr 24, 2025
69fc9a0
changed index for lower diag blocks
03szust Apr 24, 2025
467ce64
changed index for diagonal blocks
03szust Apr 24, 2025
aa5c893
inserted ifs for termination
03szust Apr 24, 2025
5ab5cf9
fixed typo
03szust Apr 24, 2025
069f355
fixed typo
03szust Apr 24, 2025
e152786
insert debug prints
03szust Apr 24, 2025
f932f67
typo
03szust Apr 24, 2025
730de95
changed b last block
03szust Apr 24, 2025
b61f864
fixed lower arrow blocks in partial
03szust Apr 24, 2025
81a063f
changed typo
03szust Apr 24, 2025
51acd9a
changed test
03szust Apr 24, 2025
66a23f8
new debug print
03szust Apr 24, 2025
7a456a5
changed logic to accomodate arrow sizes
03szust Apr 25, 2025
b7fa179
typo
03szust Apr 25, 2025
1d550cd
fixed function
03szust Apr 25, 2025
8c221e8
insert debug statements
03szust Apr 25, 2025
380340e
more debugging
03szust Apr 25, 2025
d41e4e9
debugging second to last solve
03szust Apr 25, 2025
7f8ae97
fixed second to last solve
03szust Apr 25, 2025
c2b52aa
removed debugging statements
03szust Apr 25, 2025
11f838e
insert debug statements
03szust Apr 25, 2025
133d006
fixed index typo
03szust Apr 25, 2025
2ab3087
changed debug statement
03szust Apr 25, 2025
aeca9e1
changed operation order
03szust Apr 25, 2025
78e8e25
changed to right B
03szust Apr 25, 2025
f8f5b64
setup corrected for out of bounds
03szust Apr 25, 2025
22c829a
removed debug statements
03szust Apr 25, 2025
79d78f0
insert debug statement
03szust Apr 25, 2025
f00a04a
forced streaming in tests again
03szust Apr 25, 2025
cccd82d
force streaming in pobtaf for testing
03szust Apr 25, 2025
416b0aa
removed forced streaming from pobtaf
03szust Apr 25, 2025
e943999
changed stream timing
03szust Apr 29, 2025
8a05718
added sync
03szust Apr 29, 2025
cef552c
insert debug statements
03szust Apr 29, 2025
18a8b8f
insert antoher debug statement
03szust Apr 29, 2025
8a1f9f3
removed misguided overlap protection
03szust Apr 29, 2025
5acc905
changed streaming order
03szust Apr 29, 2025
c45adc9
rolled back if statement
03szust Apr 29, 2025
3751894
debug statement to check if the last block is the problem
03szust Apr 29, 2025
c6f63d1
changed non partial solve
03szust Apr 29, 2025
ee1798c
debug to see passed tests
03szust Apr 29, 2025
c63a5de
inserted debug statements to compare B
03szust Apr 29, 2025
2083e06
changed arrow tip block
03szust Apr 29, 2025
7805321
changed stream timing
03szust Apr 29, 2025
1215712
changed if to stream b + 1
03szust Apr 29, 2025
b509c0d
debug changed to check n
03szust Apr 29, 2025
c0d0c33
consitentcy update
03szust Apr 29, 2025
6c45c15
changed non partial part
03szust Apr 29, 2025
ba9f28e
changed non partial block to match indexing
03szust Apr 29, 2025
90d6a74
first attempt at backward solve
03szust Apr 29, 2025
14a9215
fixed typo
03szust Apr 29, 2025
3ce3b0a
another typo
03szust Apr 29, 2025
99c2e3d
insert parenthesis
03szust Apr 29, 2025
af5f83d
insert debug staetments
03szust Apr 29, 2025
799af09
more debug
03szust Apr 29, 2025
5cc569e
added missing streaming
03szust Apr 29, 2025
fa95f16
added debug statements
03szust Apr 29, 2025
b394b91
changed debug
03szust Apr 29, 2025
28220ad
new debug statements
03szust Apr 29, 2025
7e19033
new debugs
03szust Apr 29, 2025
8136381
changed stream timing
03szust Apr 29, 2025
cfa8307
adjusted stram timing
03szust Apr 29, 2025
304b368
changed event recording
03szust Apr 29, 2025
dd82d4b
more debug
03szust Apr 29, 2025
5370943
insert first compare debug
03szust Apr 29, 2025
3828910
second debug compare
03szust Apr 29, 2025
e36f83d
inserted lower diagonal blocks streaming
03szust Apr 29, 2025
f71a315
debug compare 3
03szust Apr 29, 2025
a2ddd30
compare 4
03szust Apr 29, 2025
5d22f94
changed location of B_previous
03szust Apr 29, 2025
e59fd54
added previous B setup
03szust Apr 29, 2025
50e728d
fixed indexing
03szust Apr 29, 2025
31233d0
moved brevious b from if
03szust Apr 29, 2025
2133c4d
moved previous b from correct if
03szust Apr 29, 2025
533af2a
removed debug statements
03szust Apr 29, 2025
2db7273
moved a wait event
03szust Apr 29, 2025
45b7179
delayed d2h stream
03szust Apr 29, 2025
ab395e4
adjusted stream timing
03szust Apr 29, 2025
d4f0128
even more adjusted timing
03szust Apr 29, 2025
ba2d6ac
changed streaming order
03szust Apr 30, 2025
5efb03a
removed strange get
03szust Apr 30, 2025
66d2f6b
insert debug staetments
03szust Apr 30, 2025
cd2b9c7
changed debug
03szust Apr 30, 2025
9db7858
changed last get
03szust Apr 30, 2025
b37207b
more debugging
03szust Apr 30, 2025
ad6d375
changed B events
03szust Apr 30, 2025
e83d0b8
print B_d
03szust Apr 30, 2025
6fd9ff1
insert seperator print
03szust Apr 30, 2025
3bc6718
changed location of previous B event
03szust Apr 30, 2025
ab3fd2a
changed order of compute stream
03szust Apr 30, 2025
5e72b08
switched chose previous B
03szust Apr 30, 2025
7155c8a
changed wait event
03szust Apr 30, 2025
09f31c3
changed another wait event
03szust Apr 30, 2025
d75b7ff
changed stream pattern
03szust Apr 30, 2025
87fb54b
changed previous B
03szust Apr 30, 2025
464ca75
removed last B get
03szust Apr 30, 2025
ae2e269
changed indexing
03szust Apr 30, 2025
729be57
changed streaming a bit
03szust Apr 30, 2025
f801076
insert debug
03szust Apr 30, 2025
44d8582
more debug
03szust Apr 30, 2025
0c816b0
inser print B
03szust Apr 30, 2025
b0a6473
another print B
03szust Apr 30, 2025
a73f8d2
print xref
03szust Apr 30, 2025
39138d0
more debug
03szust Apr 30, 2025
8b74d46
another B_d print
03szust Apr 30, 2025
43daa68
insert last B d2h
03szust Apr 30, 2025
ac9f3d6
condition last stream
03szust Apr 30, 2025
a74bcbe
insert wait event for last stream
03szust Apr 30, 2025
3e9644b
backward solve working
03szust Apr 30, 2025
7f17c0f
bigger tests
03szust Apr 30, 2025
7f87fc7
even bigger tests
03szust Apr 30, 2025
1616131
reverted tests for now
03szust Apr 30, 2025
8335da7
first attempt at adapted code for pobts
03szust May 1, 2025
46627ec
removed not implemented error
03szust May 1, 2025
43afebe
insert debug
03szust May 1, 2025
1b05487
fixed array slicing
03szust May 1, 2025
34d6577
pobts streaming working
03szust May 1, 2025
42e215e
first attempt at pobts forward streaming by flipping it
03szust May 1, 2025
abe2879
added test logic
03szust May 1, 2025
96652d5
changed indexing
03szust May 1, 2025
e6ce6c4
fixed more indexing
03szust May 1, 2025
ea29b8e
switched event order
03szust May 1, 2025
6fab517
changed first block logic
03szust May 1, 2025
5e87ead
fixed solve
03szust May 1, 2025
82df445
insert debug statement
03szust May 1, 2025
93976bd
changed lower diagonal order
03szust May 1, 2025
72d1b83
inser debug message
03szust May 1, 2025
bf8077e
changed slicing
03szust May 1, 2025
046dffc
adjusted loop
03szust May 1, 2025
f3bc585
adjusted loop
03szust May 1, 2025
dfbf23b
changed previous B
03szust May 1, 2025
c59634a
insert debug check 1
03szust May 1, 2025
5f91706
adjusted streaming
03szust May 1, 2025
adc84f8
adjusted streaming
03szust May 1, 2025
1409c5d
insert more debug
03szust May 1, 2025
6e13d6f
expanded for loop
03szust May 1, 2025
d856785
adjusted streaming
03szust May 1, 2025
f46a64a
check number 2
03szust May 1, 2025
871a3b7
shifted indexing
03szust May 1, 2025
6f4971c
changed lower streaming
03szust May 1, 2025
8307f51
more debug
03szust May 1, 2025
742dcd3
removed some debug
03szust May 1, 2025
c4f3fed
debug number 3
03szust May 1, 2025
620ee3b
changed B streaming
03szust May 1, 2025
82b8190
more changes to B streaming
03szust May 1, 2025
f88aad8
changed B previous
03szust May 1, 2025
9e401ac
removed wrong transposition
03szust May 1, 2025
7c24151
debug check 4
03szust May 1, 2025
eff1bda
debug b previous
03szust May 1, 2025
c021ca0
moved debug message
03szust May 1, 2025
7794122
shift B previous get
03szust May 1, 2025
3e6d3c3
changed last B
03szust May 1, 2025
e5fb88a
test for last B
03szust May 1, 2025
7eeb5c1
revert
03szust May 1, 2025
2afe74b
try different stream order
03szust May 1, 2025
c06b849
insert failsafe
03szust May 1, 2025
01f67d1
more failsafe
03szust May 1, 2025
5be4c6f
removed unnecessary events
03szust May 1, 2025
e653780
stream failsafes
03szust May 1, 2025
93b669a
more failsafe
03szust May 1, 2025
c6fc65f
changed faulty event
03szust May 1, 2025
12ca640
changed last stream
03szust May 1, 2025
5efb288
removed unnecessary events
03szust May 1, 2025
1894334
more parity
03szust May 1, 2025
9db222d
more failsafes
03szust May 1, 2025
db2928d
cosmetic changes
03szust May 1, 2025
47c9f5c
more cosmetic changes
03szust May 1, 2025
3ba9b9f
attempt to reduce streaming
03szust May 1, 2025
3d9e334
reduced streaming
03szust May 1, 2025
8bf1908
attempt to reduce streaming
03szust May 1, 2025
25c9e56
parity reduced streaming
03szust May 1, 2025
83a681c
attempt to fuirther reduce streaming
03szust May 1, 2025
a685906
speed up setup attempt
03szust May 1, 2025
37118fd
expand delay reduction
03szust May 1, 2025
73996a2
comment changes
03szust May 1, 2025
9d3dda0
check for useless if
03szust May 1, 2025
c2427fe
check for duplicate
03szust May 1, 2025
2f159bf
reverted
03szust May 1, 2025
d710e38
reduced for loop
03szust May 1, 2025
e3dc9d3
reordered streaming
03szust May 1, 2025
b8871cc
moved streaming and added documentation
03szust May 1, 2025
69c20ab
bigger tests
03szust May 2, 2025
f257060
even bigger tests
03szust May 2, 2025
0416d8a
even more bigger tests
03szust May 2, 2025
b00d95b
changed tests to be smaller
03szust May 2, 2025
6dc83e6
smaller tests again
03szust May 2, 2025
2645a8f
reset tests
03szust May 2, 2025
7329ec3
add scripts for cscs
May 8, 2025
04cbd77
updarte bash script
03szust May 8, 2025
6a85bc5
removed load_modules
03szust May 9, 2025
d10b436
changed file path
03szust May 9, 2025
1a64890
change to enable streaming on daint
03szust May 9, 2025
d21064f
added check message
03szust May 15, 2025
86ce7c1
changed given arrays
03szust May 15, 2025
942a146
rolled back block choice for further testing
03szust May 15, 2025
dcc85fb
attempt to activate streaming
03szust May 15, 2025
a73b7b9
typo
03szust May 15, 2025
408628e
another typo
03szust May 15, 2025
060fd0b
enable streaming for pobtaf
03szust May 15, 2025
d08b7b3
removing copy
03szust May 15, 2025
ac47799
pinned memory
03szust May 15, 2025
5faecf6
typo
03szust May 15, 2025
e668184
changed block name
03szust May 15, 2025
2a048b0
import cupyx
03szust May 15, 2025
5bc3bcb
missing B_cpu
03szust May 15, 2025
1cb143c
changed nvtx
03szust May 15, 2025
35e5bf9
moved pop
03szust May 15, 2025
01f4b24
untangled streaming
03szust May 15, 2025
c63cd2c
modified tests
03szust May 16, 2025
a3905e2
pytest array_type override
03szust May 16, 2025
6913677
changed tests a bit to not override
03szust May 16, 2025
f734dc6
activate pobtaf streaming in tests
03szust May 16, 2025
5924e1b
removed nvtx and tests the tests
03szust May 16, 2025
c259cca
removed test testing
03szust May 16, 2025
4cf986f
expanded tests
03szust May 16, 2025
632b74c
expanded tests further
03szust May 16, 2025
680d899
activated streaming tests for pobtaf
03szust May 16, 2025
10de2c5
removed leftover cscs scripts
03szust May 16, 2025
7390a6b
removed line that forced streaming
03szust May 16, 2025
c8f89e2
first modification to get cupy and scipy implementations for trsm rig…
03szust May 27, 2025
cbd07cd
moved improvement files to new branch
03szust May 27, 2025
96fb56b
unified (and added) test streaming for pobtaf/si
vincent-maillou Jun 5, 2025
33ba467
just ran `black .`
vincent-maillou Jun 5, 2025
b812838
Merge branch 'main' into integrate_missing_streaming
vincent-maillou Jun 5, 2025
6ff4ed7
changed errors
03szust Jun 6, 2025
9803af4
Merge branch 'integrate_missing_streaming' of https://github.com/vinc…
03szust Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
410 changes: 408 additions & 2 deletions src/serinv/algs/pobtas.py

Large diffs are not rendered by default.

227 changes: 225 additions & 2 deletions src/serinv/algs/pobts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from serinv import (
ArrayLike,
_get_module_from_array,
_get_module_from_str,
)


Expand Down Expand Up @@ -41,8 +42,11 @@
else:
# Natural arrowhead
if device_streaming:
raise NotImplementedError(
"Streaming is not implemented for the natural arrowhead."
_pobts_streaming(

Check warning on line 45 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L45

Added line #L45 was not covered by tests
L_diagonal_blocks,
L_lower_diagonal_blocks,
B,
trans,
)
else:
_pobts(
Expand Down Expand Up @@ -163,3 +167,222 @@
)
else:
raise ValueError(f"Invalid transpose argument: {trans}.")


def _pobts_streaming(
L_diagonal_blocks: ArrayLike,
L_lower_diagonal_blocks: ArrayLike,
B: ArrayLike,
trans: str,
):
arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks)
if arr_module.__name__ != "numpy":
raise TypeError(

Check warning on line 180 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L178-L180

Added lines #L178 - L180 were not covered by tests
"Host<->Device streaming only works when host-arrays are given."
)

cp, cu_la = _get_module_from_str(module_str="cupy")

Check warning on line 184 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L184

Added line #L184 was not covered by tests

# Vars
diag_blocksize = L_diagonal_blocks.shape[1]
n_diag_blocks = L_diagonal_blocks.shape[0]

Check warning on line 188 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L187-L188

Added lines #L187 - L188 were not covered by tests

# Streams
compute_stream = cp.cuda.Stream(non_blocking=True)
h2d_stream = cp.cuda.Stream(non_blocking=True)
d2h_stream = cp.cuda.Stream(non_blocking=True)

Check warning on line 193 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L191-L193

Added lines #L191 - L193 were not covered by tests

# Device Buffers
# B Buffers
B_shape = B[0:diag_blocksize]
B_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype)
B_previous_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype)
del B_shape

Check warning on line 200 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L197-L200

Added lines #L197 - L200 were not covered by tests

# L Buffers
L_diagonal_blocks_d = cp.empty(

Check warning on line 203 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L203

Added line #L203 was not covered by tests
(2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype
)
L_lower_diagonal_blocks_d = cp.empty(

Check warning on line 206 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L206

Added line #L206 was not covered by tests
(2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype
)

# Events
compute_B_events = [cp.cuda.Event(), cp.cuda.Event()]
h2d_events = [cp.cuda.Event(), cp.cuda.Event()]
d2h_events = [cp.cuda.Event(), cp.cuda.Event()]

Check warning on line 213 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L211-L213

Added lines #L211 - L213 were not covered by tests

if trans == "N":

Check warning on line 215 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L215

Added line #L215 was not covered by tests
# ----- Forward substitution -----

# --- H2D: transfers ---
B_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream)
L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream)

Check warning on line 220 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L219-L220

Added lines #L219 - L220 were not covered by tests

h2d_events[1].record(stream=h2d_stream)

Check warning on line 222 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L222

Added line #L222 was not covered by tests

if n_diag_blocks > 1:
B_d[1].set(arr=B[diag_blocksize : (2 * diag_blocksize)], stream=h2d_stream)
L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream)
L_lower_diagonal_blocks_d[1].set(

Check warning on line 227 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L224-L227

Added lines #L224 - L227 were not covered by tests
arr=L_lower_diagonal_blocks[0], stream=h2d_stream
)

h2d_events[0].record(stream=h2d_stream)

Check warning on line 231 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L231

Added line #L231 was not covered by tests

with compute_stream:

Check warning on line 233 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L233

Added line #L233 was not covered by tests
# Solve first B block
compute_stream.wait_event(h2d_events[1])

Check warning on line 235 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L235

Added line #L235 was not covered by tests

B_previous_d[0] = cu_la.solve_triangular(

Check warning on line 237 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L237

Added line #L237 was not covered by tests
L_diagonal_blocks_d[0],
B_d[0],
lower=True,
)

compute_B_events[0].record(stream=compute_stream)

Check warning on line 243 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L243

Added line #L243 was not covered by tests

for i in range(1, n_diag_blocks):

Check warning on line 245 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L245

Added line #L245 was not covered by tests

if i + 1 < n_diag_blocks:

Check warning on line 247 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L247

Added line #L247 was not covered by tests
# Pass next blocks
h2d_stream.wait_event(compute_B_events[(i + 1) % 2])

Check warning on line 249 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L249

Added line #L249 was not covered by tests

B_d[(i + 1) % 2].set(

Check warning on line 251 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L251

Added line #L251 was not covered by tests
arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize],
stream=h2d_stream,
)
L_diagonal_blocks_d[(i + 1) % 2].set(

Check warning on line 255 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L255

Added line #L255 was not covered by tests
arr=L_diagonal_blocks[i + 1], stream=h2d_stream
)
L_lower_diagonal_blocks_d[(i + 1) % 2].set(

Check warning on line 258 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L258

Added line #L258 was not covered by tests
arr=L_lower_diagonal_blocks[i], stream=h2d_stream
)

h2d_events[i % 2].record(stream=h2d_stream)

Check warning on line 262 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L262

Added line #L262 was not covered by tests

with compute_stream:

Check warning on line 264 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L264

Added line #L264 was not covered by tests
# X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1}
compute_stream.wait_event(h2d_events[(i + 1) % 2])
compute_stream.wait_event(d2h_events[(i + 1) % 2])

Check warning on line 267 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L266-L267

Added lines #L266 - L267 were not covered by tests

B_previous_d[i % 2] = cu_la.solve_triangular(

Check warning on line 269 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L269

Added line #L269 was not covered by tests
L_diagonal_blocks_d[i % 2],
B_d[i % 2]
- L_lower_diagonal_blocks_d[i % 2] @ B_previous_d[(i + 1) % 2],
lower=True,
)

compute_B_events[i % 2].record(compute_stream)

Check warning on line 276 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L276

Added line #L276 was not covered by tests

# Pass previous B block back
d2h_stream.wait_event(compute_B_events[(i - 1) % 2])

Check warning on line 279 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L279

Added line #L279 was not covered by tests

B_previous_d[(i + 1) % 2].get(

Check warning on line 281 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L281

Added line #L281 was not covered by tests
out=B[(i - 1) * diag_blocksize : i * diag_blocksize],
stream=d2h_stream,
blocking=False,
)

d2h_events[i % 2].record(stream=d2h_stream)

Check warning on line 287 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L287

Added line #L287 was not covered by tests

# Pass last B block back
d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2])

Check warning on line 290 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L290

Added line #L290 was not covered by tests

B_previous_d[(n_diag_blocks + 1) % 2].get(

Check warning on line 292 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L292

Added line #L292 was not covered by tests
out=B[-diag_blocksize:], stream=d2h_stream, blocking=False
)

elif trans == "T" or trans == "C":

Check warning on line 296 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L296

Added line #L296 was not covered by tests
# ----- Backward substitution -----

# --- H2D: transfers ---
B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream)
L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set(

Check warning on line 301 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L300-L301

Added lines #L300 - L301 were not covered by tests
arr=L_diagonal_blocks[-1], stream=h2d_stream
)

h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream)

Check warning on line 305 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L305

Added line #L305 was not covered by tests

if n_diag_blocks > 1:

Check warning on line 307 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L307

Added line #L307 was not covered by tests

B_d[n_diag_blocks % 2].set(

Check warning on line 309 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L309

Added line #L309 was not covered by tests
arr=B[-(2 * diag_blocksize) : -diag_blocksize], stream=h2d_stream
)
L_diagonal_blocks_d[n_diag_blocks % 2].set(

Check warning on line 312 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L312

Added line #L312 was not covered by tests
arr=L_diagonal_blocks[-2], stream=h2d_stream
)
L_lower_diagonal_blocks_d[n_diag_blocks % 2].set(

Check warning on line 315 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L315

Added line #L315 was not covered by tests
arr=L_lower_diagonal_blocks[-1], stream=h2d_stream
)

h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream)

Check warning on line 319 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L319

Added line #L319 was not covered by tests

with compute_stream:

Check warning on line 321 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L321

Added line #L321 was not covered by tests
# X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1})
compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2])

Check warning on line 323 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L323

Added line #L323 was not covered by tests

B_previous_d[(n_diag_blocks - 1) % 2] = cu_la.solve_triangular(

Check warning on line 325 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L325

Added line #L325 was not covered by tests
L_diagonal_blocks_d[(n_diag_blocks - 1) % 2],
B_d[(n_diag_blocks - 1) % 2],
lower=True,
trans="C",
)

compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream)

Check warning on line 332 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L332

Added line #L332 was not covered by tests

for i in range(n_diag_blocks - 2, -1, -1):

Check warning on line 334 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L334

Added line #L334 was not covered by tests

if i > 0:

Check warning on line 336 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L336

Added line #L336 was not covered by tests
# pass next blocks
h2d_stream.wait_event(compute_B_events[(i - 1) % 2])

Check warning on line 338 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L338

Added line #L338 was not covered by tests

B_d[(i - 1) % 2].set(

Check warning on line 340 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L340

Added line #L340 was not covered by tests
arr=B[(i - 1) * diag_blocksize : i * diag_blocksize],
stream=h2d_stream,
)
L_diagonal_blocks_d[(i - 1) % 2].set(

Check warning on line 344 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L344

Added line #L344 was not covered by tests
arr=L_diagonal_blocks[i - 1], stream=h2d_stream
)
L_lower_diagonal_blocks_d[(i - 1) % 2].set(

Check warning on line 347 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L347

Added line #L347 was not covered by tests
arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream
)

h2d_events[i % 2].record(stream=h2d_stream)

Check warning on line 351 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L351

Added line #L351 was not covered by tests

with compute_stream:

Check warning on line 353 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L353

Added line #L353 was not covered by tests
# X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1}
compute_stream.wait_event(h2d_events[(i - 1) % 2])
compute_stream.wait_event(d2h_events[(i - 1) % 2])

Check warning on line 356 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L355-L356

Added lines #L355 - L356 were not covered by tests

B_previous_d[i % 2] = cu_la.solve_triangular(

Check warning on line 358 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L358

Added line #L358 was not covered by tests
L_diagonal_blocks_d[i % 2],
B_d[i % 2]
- L_lower_diagonal_blocks_d[i % 2].conj().T
@ B_previous_d[(i - 1) % 2],
lower=True,
trans="C",
)

compute_B_events[i % 2].record(compute_stream)

Check warning on line 367 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L367

Added line #L367 was not covered by tests

# Pass previous B block back
d2h_stream.wait_event(compute_B_events[(i - 1) % 2])

Check warning on line 370 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L370

Added line #L370 was not covered by tests

B_previous_d[(i - 1) % 2].get(

Check warning on line 372 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L372

Added line #L372 was not covered by tests
out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize],
stream=d2h_stream,
blocking=False,
)

d2h_events[i % 2].record(stream=d2h_stream)

Check warning on line 378 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L378

Added line #L378 was not covered by tests

# Pass last B block back
d2h_stream.wait_event(compute_B_events[0])

Check warning on line 381 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L381

Added line #L381 was not covered by tests

B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False)

Check warning on line 383 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L383

Added line #L383 was not covered by tests

else:
raise ValueError(f"Invalid transpose argument: {trans}.")

Check warning on line 386 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L386

Added line #L386 was not covered by tests

cp.cuda.Device().synchronize()

Check warning on line 388 in src/serinv/algs/pobts.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/algs/pobts.py#L388

Added line #L388 was not covered by tests
1 change: 0 additions & 1 deletion src/serinv/wrappers/ddbtars.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import cupyx as cpx
import cupy as cp


def allocate_ddbtars(
A_diagonal_blocks: ArrayLike,
A_lower_diagonal_blocks: ArrayLike,
Expand Down
2 changes: 1 addition & 1 deletion src/serinv/wrappers/pddbtasci.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def pddbtasci(
The arrow tip block of the block tridiagonal with arrowhead matrix.
comm : MPI.Comm
The MPI communicator. Default is MPI.COMM_WORLD.

Keyword Arguments
-----------------
rhs : dict
Expand Down
4 changes: 1 addition & 3 deletions src/serinv/wrappers/pddbtsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,4 @@
quadratic=quadratic,
)

comm.Barrier()

return elapsed
comm.Barrier()

Check warning on line 195 in src/serinv/wrappers/pddbtsc.py

View check run for this annotation

Codecov / codecov/patch

src/serinv/wrappers/pddbtsc.py#L195

Added line #L195 was not covered by tests
2 changes: 1 addition & 1 deletion src/serinv/wrappers/pddbtsci.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def pddbtsci(
The upper diagonal blocks of the block tridiagonal matrix.
comm : MPI.Comm
The MPI communicator. Default is MPI.COMM_WORLD.

Keyword Arguments
-----------------
rhs : dict
Expand Down
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright 2023-2025 ETH Zurich. All rights reserved.
# Global pytest fixtures for the Serinv tests.

import pytest

from serinv import backend_flags
Expand All @@ -15,7 +14,6 @@
]
)


DTYPE = [
pytest.param("float64", id="float64"),
pytest.param("complex128", id="complex128"),
Expand Down
12 changes: 12 additions & 0 deletions tests/tests_algs/regular/tests_bt/test_pobtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import pytest

from ....conftest import ARRAY_TYPE

from serinv import backend_flags, _get_module_from_array
from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize

Expand All @@ -11,6 +13,16 @@
if backend_flags["cupy_avail"]:
import cupyx as cpx

ARRAY_TYPE.extend(
[
pytest.param("streaming", id="streaming"),
]
)

@pytest.fixture(params=ARRAY_TYPE, autouse=True)
def array_type(request: pytest.FixtureRequest) -> str:
return request.param


@pytest.mark.mpi_skip()
def test_pobtf(
Expand Down
33 changes: 32 additions & 1 deletion tests/tests_algs/regular/tests_bt/test_pobts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,26 @@
import numpy as np
import pytest

from serinv import _get_module_from_array
from ....conftest import ARRAY_TYPE

from serinv import backend_flags, _get_module_from_array
from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize, rhs

from serinv.algs import pobtf, pobts

if backend_flags["cupy_avail"]:
import cupyx as cpx

ARRAY_TYPE.extend(
[
pytest.param("streaming", id="streaming"),
]
)

@pytest.fixture(params=ARRAY_TYPE, autouse=True)
def array_type(request: pytest.FixtureRequest) -> str:
return request.param


@pytest.mark.mpi_skip()
@pytest.mark.parametrize("n_rhs", [1, 2, 3])
Expand All @@ -18,6 +33,7 @@ def test_pobts(
array_type: str,
dtype: np.dtype,
):

A = dd_bt(
diagonal_blocksize,
n_diag_blocks,
Expand Down Expand Up @@ -47,9 +63,22 @@ def test_pobts(
_,
) = bt_dense_to_arrays(A, diagonal_blocksize, n_diag_blocks)

if backend_flags["cupy_avail"] and array_type == "streaming":
A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks)
A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :]
A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks)
A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks[:, :, :]
B_pinned = cpx.zeros_like_pinned(B)
B_pinned[:, :] = B[:, :]

A_diagonal_blocks = A_diagonal_blocks_pinned
A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned
B = B_pinned

pobtf(
A_diagonal_blocks,
A_lower_diagonal_blocks,
device_streaming=True if array_type == "streaming" else False,
)

# Forward solve: Y=L^{-1}B
Expand All @@ -58,6 +87,7 @@ def test_pobts(
A_lower_diagonal_blocks,
B,
trans="N",
device_streaming=True if array_type == "streaming" else False,
)

# Backward solve: X=L^{-T}Y
Expand All @@ -66,6 +96,7 @@ def test_pobts(
A_lower_diagonal_blocks,
B,
trans="C",
device_streaming=True if array_type == "streaming" else False,
)

assert xp.allclose(B, X_ref)
Loading