From f7816d854e49b24682942d5291c11b08d6a76409 Mon Sep 17 00:00:00 2001 From: GeoffNN Date: Thu, 18 Jun 2020 15:47:43 +0200 Subject: [PATCH] Parallelization of fast_csr_mv --- copt/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/copt/utils.py b/copt/utils.py index 1115f624..06d7a22b 100644 --- a/copt/utils.py +++ b/copt/utils.py @@ -112,7 +112,7 @@ def fast_csr_vm(x, data, indptr, indices, d, idx): return res -@njit(nogil=True) +@njit(parallel=True) def fast_csr_mv(data, indptr, indices, x, idx): """ Returns the matrix vector product M[idx] * x. M is described @@ -126,10 +126,13 @@ def fast_csr_mv(data, indptr, indices, x, idx): """ res = np.zeros(len(idx)) - for i, row_idx in np.ndenumerate(idx): - for k, j in enumerate(range(indptr[row_idx], indptr[row_idx+1])): + for i in prange(len(idx)): + row_idx = idx[i] + res_i = 0.0 + for j in range(indptr[row_idx], indptr[row_idx+1]): j_idx = indices[j] - res[i] += x[j_idx] * data[j] + res_i += x[j_idx] * data[j] + res[i] = res_i return res