import cupy as cp
import numpy as np
import math
from functools import cache
_vals64_src = r"""
extern "C"{
__global__ void gen_values_kernel_f64(const double* __restrict__ vals,
const int L, const int d,
const unsigned long long start_rank,
const int count,
double* __restrict__ out, const int ld)
{
int t = blockDim.x * blockIdx.x + threadIdx.x;
if (t >= count) return;
unsigned long long idx = start_rank + (unsigned long long)t;
    // column t, rows 0..d-1 (Fortran layout: out[t*ld + row])
for (int p = d - 1; p >= 0; --p){
unsigned long long q = idx / (unsigned long long)L;
unsigned int rem = (unsigned int)(idx - q * (unsigned long long)L);
out[(size_t)t * ld + p] = vals[rem];
idx = q;
}
}
}"""
_mod_vals64 = cp.RawModule(code=_vals64_src)
_gen_vals64 = _mod_vals64.get_function("gen_values_kernel_f64")
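
# --- Added illustrative sketch (not part of the original module) ---
# A float64 batcher mirroring value_batches_fp32_gpu defined further below, driving
# the gen_values_kernel_f64 compiled above with the same base-L digit enumeration.
# The function name and the local threads-per-block value are our own choices;
# treat this as a hedged example, assuming a CUDA device is available.
def value_batches_fp64_gpu_sketch(values, d: int, batch_size: int):
    vals_dev = (
        values
        if isinstance(values, cp.ndarray) and values.dtype == cp.float64
        else cp.asarray(values, dtype=cp.float64)
    )
    L = int(vals_dev.size)
    total = L**d
    assert total < (1 << 64), "L**d too large for uint64."
    tpb = 256  # local constant; the module-level _TPB defined below would also do
    start = 0
    while start < total:
        B = min(batch_size, total - start)
        V_gpu = cp.empty((d, B), dtype=cp.float64, order="F")
        grid = ((B + tpb - 1) // tpb,)
        _gen_vals64(
            grid,
            (tpb,),
            (vals_dev, cp.int32(L), cp.int32(d), cp.uint64(start), cp.int32(B), V_gpu, cp.int32(d)),
        )
        yield V_gpu
        start += B
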
_src = r"""
extern "C"{
// values-product: enumerate base-L digits -> V (d, B) in Fortran layout
__global__ void gen_values_kernel(const float* __restrict__ vals,
const int L, const int d,
const unsigned long long start_rank,
const int count,
float* __restrict__ out, const int ld)
{
int t = blockDim.x * blockIdx.x + threadIdx.x;
if (t >= count) return;
unsigned long long idx = start_rank + (unsigned long long)t;
// write column t, rows 0..d-1 (Fortran: out[t*ld + row])
// compute most-significant first to match itertools.product order
for (int p = d - 1; p >= 0; --p){
unsigned long long q = idx / (unsigned long long)L;
unsigned int rem = (unsigned int)(idx - q * (unsigned long long)L);
out[(size_t)t * ld + p] = vals[rem];
idx = q;
}
}
__device__ __forceinline__ unsigned long long
C_at(const unsigned long long* __restrict__ choose, int row, int col, int jdim) {
#if __CUDA_ARCH__ >= 350
return __ldg(&choose[(size_t)row * jdim + col]); // read-only cache
#else
return choose[(size_t)row * jdim + col];
#endif
}
// Lexicographic unranking with binary search per coordinate.
// choose shape: (n+1, jdim), row-major: choose[row * jdim + col]
__global__ void unrank_combinations_lex_bs(const unsigned long long* __restrict__ choose,
const int n, const int k,
const unsigned long long start_rank,
const int count,
int* __restrict__ out,
const int jdim)
{
// grid-stride loop: support very large 'count'
for (int t = blockIdx.x * blockDim.x + threadIdx.x; t < count;
t += blockDim.x * gridDim.x)
{
unsigned long long s = start_rank + (unsigned long long)t; // rank for this thread
int a = 0; // minimal value for current coordinate
const int out_base = t * k; // row-major output
// Build combo in increasing order (k entries)
        // j goes from k down to 1 (same scheme as the CPU reference code)
for (int j = k; j >= 1; --j) {
const int max_x = n - j; // last admissible x (need room for j items)
// T = C(n - a, j), thr = T - s
const int row_na = n - a;
const unsigned long long T = C_at(choose, row_na, j, jdim);
            const unsigned long long thr = T - s;              // > 0 (since s < T is guaranteed)
// find the **largest** x in [a, max_x] with C(n - x, j) >= thr
int lo = a, hi = max_x, ans = a; // invariant: ans is last true
while (lo <= hi) {
const int mid = (lo + hi) >> 1;
const unsigned long long c = C_at(choose, n - mid, j, jdim);
if (c >= thr) { ans = mid; lo = mid + 1; }
else { hi = mid - 1; }
}
const int x = ans;
            // write output (row-major). Option: column-major for coalesced writes (see notes)
out[out_base + (k - j)] = x;
// update rank remainder and next lower bound
const unsigned long long c_x = C_at(choose, n - x, j, jdim);
            // s <- s - (T - C(n - x, j))   (the remainder stays in [0, C(n - x, j) - 1])
s -= (T - c_x);
a = x + 1;
}
}
}
} // extern "C"
""".strip()
_mod = cp.RawModule(code=_src, options=("-std=c++11",))
_kernel_vals = _mod.get_function("gen_values_kernel")
_kernel_comb = _mod.get_function("unrank_combinations_lex_bs")
_TPB = 256
# ===== helpers =====
def _build_choose_table_dev(n: int, k: int):
"""
    choose_dev[u, j] = C(u, j) for u in [0..n], j in [0..k-1]
    (we only need j up to k-1 for the lexicographic unranking)
"""
C = np.zeros((n + 1, k), dtype=np.uint64)
C[:, 0] = 1
for u in range(1, n + 1):
up_to = min(u, k - 1)
for j in range(1, up_to + 1):
C[u, j] = C[u - 1, j] + C[u - 1, j - 1]
    return cp.asarray(C)  # uploaded to the device once
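
# Added sanity check (illustrative only; the helper name is ours, not part of the
# original file): for small n and k the device table should agree with math.comb
# entry-by-entry. Assumes a CUDA device is available; safe to remove.
def _check_choose_table(n: int = 8, k: int = 4) -> bool:
    C_host = cp.asnumpy(_build_choose_table_dev(n, k))
    return all(
        int(C_host[u, j]) == math.comb(u, j)
        for u in range(n + 1)
        for j in range(k)
    )
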
# ===== GPU batchers =====
def value_batches_fp32_gpu(values, d: int, batch_size: int):
"""
    Generate blocks V_gpu of shape (d, B) directly on the GPU (no host-to-device
    transfer per batch). Each block equals
    list(islice(product(values, repeat=d), ...)).T, which would be CPU-bound if
    built on the host.
values: 1D array-like
"""
vals_dev = (
values
if isinstance(values, cp.ndarray)
else cp.asarray(values, dtype=cp.float32)
)
L = int(vals_dev.size)
total = L**d
    assert total < (1 << 64), "L**d too large for uint64."
start = 0
while start < total:
B = min(batch_size, total - start)
V_gpu = cp.empty((d, B), dtype=cp.float32, order="F")
grid = ((B + _TPB - 1) // _TPB,)
_kernel_vals(
grid,
(_TPB,),
(
vals_dev,
cp.int32(L),
cp.int32(d),
cp.uint64(start),
cp.int32(B),
V_gpu,
cp.int32(d),
),
)
yield V_gpu
start += B
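
# Added usage sketch (illustrative; the helper name is ours): concatenating the
# batches column-wise should reproduce the transpose of
# list(product(values, repeat=d)) in the same order. Assumes a CUDA device.
def _check_value_batches(values=(0.0, 1.0, 2.0), d: int = 3, batch_size: int = 5) -> bool:
    from itertools import product
    vals = np.asarray(values, dtype=np.float32)
    ref = np.array(list(product(vals, repeat=d)), dtype=np.float32).T
    got = cp.asnumpy(cp.concatenate(list(value_batches_fp32_gpu(vals, d, batch_size)), axis=1))
    return np.array_equal(ref, got)
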
def guess_batches_gpu(r: int, d: int, batch_size: int, choose_dev: cp.ndarray = None):
"""
    Generate batches of d-combinations of range(r), shape (G, d), on the GPU in lexicographic order.
choose_dev: optional C(u, j) table (if None, it is built and stored on the device).
"""
choose = choose_dev if choose_dev is not None else _build_choose_table_dev(r, d)
total = math.comb(r, d)
assert total < (1 << 64), "C(r, d) too large for uint64."
start = 0
while start < total:
G = min(batch_size, total - start)
idxs_gpu = cp.empty((G, d), dtype=cp.int32) # C-order
grid = ((G + _TPB - 1) // _TPB,)
_kernel_comb(
grid,
(_TPB,),
(
choose,
cp.int32(r),
cp.int32(d),
cp.uint64(start),
cp.int32(G),
idxs_gpu,
cp.int32(choose.shape[1]),
),
)
yield idxs_gpu
start += G
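
# Added usage sketch (illustrative; the helper name is ours): stacking the batches
# row-wise should reproduce list(combinations(range(r), d)) in lexicographic order,
# which is what the unranking kernel is meant to emit. Assumes a CUDA device.
def _check_guess_batches(r: int = 7, d: int = 3, batch_size: int = 10) -> bool:
    from itertools import combinations
    ref = np.array(list(combinations(range(r), d)), dtype=np.int32)
    got = cp.asnumpy(cp.concatenate(list(guess_batches_gpu(r, d, batch_size)), axis=0))
    return np.array_equal(ref, got)
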
@cache
def __reduction_ranges(n):
"""
    Return the list of ranges that need to be reduced.
    More generally, it returns, without using recursion and up to the relative
    order of disjoint triples, what the following Python program would output:
    <<<BEGIN CODE>>>
    def rec_range(n):
        bc, res = [], []
        def F(l, r):
            if r - l <= 1:
                return
            if r - l == 2:
                bc.append(l)
            else:
                m = (l + r) // 2
                F(l, m)
                F(m, r)
                res.append((l, m, r))
        F(0, n)
        return bc, res
    <<<END CODE>>>
:param n: the length of the array that requires reduction
    :return: pair containing `base_cases` and `result`.
`base_cases` is a list of indices `i` such that:
`i + 1` needs to be reduced w.r.t. `i`.
`result` is a list of triples `(i, j, k)` such that:
`[j:k)` needs to be reduced w.r.t. `[i:j)`.
The guarantee is that for any 0 <= i < j < n:
1) `i in base_cases && j = i + 1`,
OR
2) there is a triple (u, v, w) such that `i in [u, v)` and `j in [v, w)`.
"""
bit_shift, parts, result, base_cases = 1, 1, [], []
while parts < n:
left_bound, left_idx = 0, 0
for i in range(1, parts + 1):
right_bound = left_bound + 2 * n
mid_idx = (left_bound + n) >> bit_shift
right_idx = right_bound >> bit_shift
if right_idx > left_idx + 1:
# Only consider nontrivial intervals
if right_idx == left_idx + 2:
# Return length 2 intervals separately to unroll base case.
base_cases.append(left_idx)
else:
# Properly sized interval:
result.append((left_idx, mid_idx, right_idx))
left_bound, left_idx = right_bound, right_idx
parts *= 2
bit_shift += 1
return base_cases, list(reversed(result))
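
# Added check of the documented guarantee (illustrative; the helper name is ours):
# every pair of indices 0 <= i < j < n is either an adjacent base-case pair or
# separated by some triple (u, v, w). Pure CPU; no GPU needed.
def _check_reduction_ranges(n: int = 12) -> bool:
    base_cases, result = __reduction_ranges(n)
    def covered(i, j):
        if i in base_cases and j == i + 1:
            return True
        return any(u <= i < v <= j < w for (u, v, w) in result)
    return all(covered(i, j) for i in range(n) for j in range(i + 1, n))
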
@cache
def __babai_ranges(n):
# Assume all indices are base cases initially
range_around = [False] * n
for i, j, k in __reduction_ranges(n)[1]:
# Mark node `j` as responsible to reduce [i, j) wrt [j, k) once Babai is at/past index j.
range_around[j] = (i, k)
return range_around
babai_reduce_step = cp.ElementwiseKernel(
in_params="T t, T invd, T d",
out_params="T uo, T tout",
operation=r"""
T u = - nearbyint(t * invd);
uo = u;
tout = fma(d, u, t);
""",
name="babai_reduce_step",
)
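
# Added illustration (ours, not part of the original module) of the step above:
# for a diagonal entry d and a target component t, u = -nearbyint(t / d) and the
# updated component is t + d * u = t - d * round(t / d), i.e. the residual after
# snapping t to the nearest integer multiple of d. Assumes a CUDA device.
def _demo_babai_reduce_step():
    t = cp.asarray([3.7, -1.2, 0.49], dtype=cp.float64)
    d = cp.asarray([2.0, 2.0, 1.0], dtype=cp.float64)
    u, t_new = babai_reduce_step(t, 1.0 / d, d)
    # expected roughly: u == [-2., 1., 0.] and t_new == [-0.3, 0.8, 0.49]
    return u, t_new
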
# Scalar AXPY via FMA: row += alpha * vec
axpy_scalar_row = cp.ElementwiseKernel(
in_params="T row_in, T vec, T alpha",
out_params="T row_out",
operation=r"""
row_out = fma(alpha, vec, row_in);
""",
name="axpy_scalar_row",
)
def nearest_plane_gpu(R, T, U, range_around, diag, inv_diag):
"""
In-place Babai nearest plane on GPU.
R: (n,n) upper-triangular, cupy ndarray (float32/float64), Fortran order preferred
T: (n,N) targets, cupy ndarray, Fortran order preferred
    U: (n,N) integer coefficients (same dtype as T is fine; entries are produced by nearbyint), Fortran order preferred
    range_around: precomputed per-index ranges as returned by __babai_ranges(n):
        either False or a tuple (i, k) for each index j
Side-effects: updates T <- T + R @ U and fills U.
"""
Rm = R
Tm = T
Um = U
n, N = Tm.shape
for j in range(n - 1, -1, -1):
u_j, new_Tj = babai_reduce_step(Tm[j, :], inv_diag[j], diag[j])
Um[j, :] = u_j
Tm[j, :] = new_Tj
ra = range_around[j]
if ra:
i, k = ra
R12 = Rm[i:j, j:k]
U2 = Um[j:k, :]
Tm[i:j, :] += R12 @ U2
else:
if j > 0:
alpha = Rm[j - 1, j]
Tm[j - 1, :] = axpy_scalar_row(Tm[j - 1, :], Um[j, :], alpha)
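
# Added end-to-end usage sketch (illustrative; helper name, sizes and dtype are ours):
# run the GPU nearest-plane on a random upper-triangular R and random targets, then
# check the standard Babai property that every residual row satisfies
# |T[j, :]| <= |R[j, j]| / 2. Assumes a CUDA device is available.
def _demo_nearest_plane(n: int = 32, N: int = 64, seed: int = 0):
    rng = np.random.default_rng(seed)
    R_host = np.triu(rng.standard_normal((n, n)))
    # keep the diagonal well away from zero so the reduction is well conditioned
    R_host[np.diag_indices(n)] = 1.0 + np.abs(R_host[np.diag_indices(n)])
    R = cp.array(R_host, dtype=cp.float64, order="F")
    T = cp.array(rng.standard_normal((n, N)), dtype=cp.float64, order="F")
    U = cp.zeros((n, N), dtype=cp.float64, order="F")
    diag = cp.diag(R).copy()
    inv_diag = 1.0 / diag
    nearest_plane_gpu(R, T, U, __babai_ranges(n), diag, inv_diag)
    # every residual component should now be within half the corresponding diagonal entry
    ok = bool(cp.all(cp.abs(T) <= cp.abs(diag)[:, None] / 2 + 1e-9))
    return ok, U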