Compute the binary matrix-vector product y = M · v, where:
- `M` is an `n × n` binary matrix (entries are 0 or 1)
- `v` is a real-valued vector of length `n`
- `y` is a real-valued output vector of length `n`
The naive approach costs O(n²) multiply-adds. RSR exploits the fact that binary matrices (especially in quantized neural networks) often have many duplicate columns. Identical columns contribute identically to each output row, so their input vector elements can be summed first and the sum applied once per output row. The per-row work then scales with the number of unique column patterns instead of the number of columns.
For a binary matrix, a set of columns with the same bit pattern contributes to the output identically: each of those columns adds the same set of rows. If columns j1, j2, j3 all have the pattern [1, 0, 1]ᵀ, then their joint contribution to the output is:
y[0] += v[j1] + v[j2] + v[j3]
y[2] += v[j1] + v[j2] + v[j3]
Instead of three separate additions per output row, RSR computes the scalar sum s = v[j1] + v[j2] + v[j3] once and then distributes it to rows 0 and 2. This is the Aggregate then Scatter pattern.
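The following is a minimal sketch of this idea on made-up data. The matrix, vector and column indices are hypothetical and chosen only so that three columns share the pattern [1, 0, 1]ᵀ.

```python
# Hypothetical 3-row block: columns 0, 2 and 3 share the pattern [1, 0, 1]^T.
M = [
    [1, 0, 1, 1],
    [0, 1, 0, 0],
    [1, 0, 1, 1],
]
v = [0.5, 2.0, 1.5, 3.0]

# Naive product: one multiply-add per (row, column) pair.
y_naive = [sum(M[r][j] * v[j] for j in range(len(v))) for r in range(len(M))]

# Aggregate then scatter for the shared pattern.
shared_cols = [0, 2, 3]
s = sum(v[j] for j in shared_cols)  # aggregate once
y_rsr = [0.0, 0.0, 0.0]
for r in (0, 2):                    # rows where the pattern has a 1
    y_rsr[r] += s                   # scatter
y_rsr[1] += v[1]                    # column 1 is alone in its pattern [0, 1, 0]^T

print(y_naive, y_rsr)
```

Both lists should agree, since the grouped sum only reorders the additions.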
Processing all n rows at once would require encoding each column as an n-bit integer — impractical for large n. RSR instead splits M horizontally into blocks of k rows each.
Parameters:
- `n`: matrix dimension (`M` must be square)
- `k`: block height (number of rows per block); `n` must be divisible by `k`
- `num_blocks = n / k`
Each block is processed independently. Block b covers rows [b·k, (b+1)·k) of M. Columns are encoded as k-bit integers within each block, so the maximum number of unique patterns per block is 2^k.
The outputs from all blocks are concatenated to produce the final n-dimensional result.
Done once when the matrix M is loaded. The result is a set of data structures that accelerate all future inferences.
For each block b (rows b·k to (b+1)·k - 1):
Extract the sub-matrix block = M[b·k : (b+1)·k, :] of shape (k, n).
Encode each column as a k-bit integer where row 0 is the most significant bit (MSB):
col_value[j] = sum over row i of: M[b·k + i, j] * 2^(k-1-i)
Example with k=3: column [1, 0, 1]ᵀ → 1·4 + 0·2 + 1·1 = 5.
Using dot product notation:
bit_weights = [2^(k-1), 2^(k-2), ..., 2^1, 2^0] # shape (k,)
col_values = bit_weights @ block # shape (n,)
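As an illustration of the encoding step, here is a small pure-Python sketch. The helper name `encode_columns` is an assumption made for this example, and the block is represented as a list of `k` rows rather than an array.

```python
def encode_columns(block):
    """Encode each column of a k x n 0/1 block as a k-bit integer (row 0 = MSB)."""
    k = len(block)
    n = len(block[0])
    bit_weights = [1 << (k - 1 - i) for i in range(k)]  # [2^(k-1), ..., 2^0]
    return [sum(bit_weights[i] * block[i][j] for i in range(k)) for j in range(n)]


# The k=3 column [1, 0, 1]^T from the example above.
single = encode_columns([[1], [0], [1]])

# A 2 x 4 block with four distinct column patterns.
values = encode_columns([[1, 0, 1, 0],
                         [0, 1, 1, 0]])
print(single, values)
```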
Compute a permutation perm that sorts columns by col_values:
perm = argsort(col_values, stable=True) # shape (n,)
After applying perm, columns with the same bit pattern are grouped together.
Apply perm to col_values to get sorted_values, then find unique values:
sorted_values = col_values[perm]
uniq, inverse = unique(sorted_values)
- `uniq`: array of unique k-bit integer patterns, shape `(num_unique,)`
- `inverse`: for each position in `sorted_values`, its index into `uniq`, shape `(n,)`
The inverse array (called group_indices) maps each sorted column position to its pattern group.
Group boundaries: since columns are sorted, all columns with pattern group g occupy a contiguous range in perm. Compute cumulative counts to find these boundaries:
counts[g] = number of columns with group index g
group_ends[g] = sum of counts[0..g] # cumulative, i.e. group g covers [group_ends[g-1], group_ends[g])
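The sorting and grouping steps can be sketched in plain Python as follows. The function name `sort_and_group` is hypothetical, and Python's built-in stable `sorted` stands in for the stable argsort.

```python
def sort_and_group(col_values):
    """Return (perm, uniq, group_ends) for a list of column pattern integers."""
    n = len(col_values)
    # sorted() is stable, so columns with equal patterns keep their original order.
    perm = sorted(range(n), key=lambda j: col_values[j])
    sorted_values = [col_values[j] for j in perm]

    uniq, group_ends = [], []
    for pos, val in enumerate(sorted_values):
        if uniq and val == uniq[-1]:
            group_ends[-1] = pos + 1   # extend the current group
        else:
            uniq.append(val)           # start a new group
            group_ends.append(pos + 1)
    return perm, uniq, group_ends


perm, uniq, group_ends = sort_and_group([2, 2, 1, 1])
print(perm, uniq, group_ends)
```

With two pairs of duplicate columns, the four columns collapse into two groups, and group `g` covers positions `[group_ends[g-1], group_ends[g])` of `perm`.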
For each unique pattern integer pat, find which bit positions are set — these are the output row indices to update during scatter:
for each unique pattern pat:
scatter_rows_for_pat = [k-1-bit for each set bit in pat]
The mapping from bit position to row index: bit position b (0 = LSB) in the integer corresponds to row index k-1-b in the block (because row 0 is MSB).
Store a flat array scatter_rows and an offset array scatter_offsets so that pattern g's output rows are scatter_rows[scatter_offsets[g] : scatter_offsets[g+1]].
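A short sketch of this flat layout, assuming a hypothetical helper `build_scatter` that takes the sorted unique patterns and the block height:

```python
def build_scatter(uniq, k):
    """Decode each pattern into block-local row indices, stored as offsets + flat rows."""
    scatter_offsets = [0]
    scatter_rows = []
    for pat in uniq:
        for bit_pos in range(k):                      # bit_pos 0 = LSB
            if pat & (1 << bit_pos):
                scatter_rows.append(k - 1 - bit_pos)  # row 0 = MSB
        scatter_offsets.append(len(scatter_rows))
    return scatter_offsets, scatter_rows


offsets, rows = build_scatter([0, 1, 2, 3], k=2)
# Rows of group g are rows[offsets[g]:offsets[g + 1]].
print(offsets, rows)
```

The offsets array has one more entry than there are groups, so an empty pattern (all zeros) is simply a group whose slice has length zero.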
| Name | Shape | Type | Description |
|---|---|---|---|
| `perms[b]` | `(n,)` | int | Column permutation sorting by pattern |
| `group_ends[b]` | `(num_unique_b,)` | int | Cumulative group end indices into perm |
| `scatter_offsets[b]` | `(num_unique_b + 1,)` | int | Offsets into scatter_rows per group |
| `scatter_rows[b]` | `(total_set_bits_b,)` | int | Output row indices (within block) for each group |
Given preprocessed data and input vector v of shape (n,):
For each block b, reorder v according to perms[b]:
v_perm[b] = v[perms[b]] # shape (n,)
After permutation, elements that belong to the same column group are contiguous.
For each block b and each group g:
start = group_ends[g-1] (or 0 if g == 0)
end = group_ends[g]
aggregated[g] = sum(v_perm[b][start : end])
This sums all input vector elements corresponding to columns with the same pattern.
For each block b, initialize output slice out[b·k : (b+1)·k] = 0.
For each group g:
s_begin = scatter_offsets[g]
s_end = scatter_offsets[g + 1]
for s in range(s_begin, s_end):
out[b·k + scatter_rows[s]] += aggregated[g]
This distributes the aggregated sum to every output row where the group's pattern has a 1-bit.
Concatenate results from all blocks: y = out[0:n].
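For a single block, the three inference steps can be sketched as below. This variant fuses aggregation and scatter into one loop over groups, which is a simplification chosen for the example. The function name `infer_block` and the input arrays are illustrative assumptions.

```python
def infer_block(v, perm, group_ends, scatter_offsets, scatter_rows, k):
    """Permute, aggregate and scatter for one block; returns the k block outputs."""
    v_perm = [v[j] for j in perm]           # permute input
    out = [0.0] * k
    start = 0
    for g, end in enumerate(group_ends):
        agg = sum(v_perm[start:end])        # aggregate group g
        for s in range(scatter_offsets[g], scatter_offsets[g + 1]):
            out[scatter_rows[s]] += agg     # scatter to rows with a 1-bit
        start = end
    return out


# Hypothetical preprocessed arrays for a 2 x 4 block with four distinct patterns.
block_out = infer_block(
    v=[1, 2, 3, 4],
    perm=[3, 1, 0, 2],
    group_ends=[1, 2, 3, 4],
    scatter_offsets=[0, 0, 1, 2, 4],
    scatter_rows=[1, 0, 1, 0],
    k=2,
)
print(block_out)
```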
import numpy as np


def preprocess(M, k):
    """Build the per-block RSR index for a square 0/1 NumPy array M."""
    n = M.shape[0]
    assert n % k == 0
    num_blocks = n // k
    perms = []
    group_ends_list = []
    scatter_offsets_list = []
    scatter_rows_list = []
    for b in range(num_blocks):
        block = M[b*k : (b+1)*k, :]  # shape (k, n)
        # Step 1.1: encode columns (int() avoids overflow for narrow dtypes)
        bit_weights = [2**(k-1-i) for i in range(k)]
        col_values = [sum(bit_weights[i] * int(block[i, j]) for i in range(k))
                      for j in range(n)]  # length n
        # Step 1.2: sort columns
        perm = np.argsort(col_values, kind="stable")
        # Step 1.3: find groups
        sorted_values = [col_values[perm[j]] for j in range(n)]
        uniq = sorted(set(sorted_values))
        pat_to_idx = {pat: idx for idx, pat in enumerate(uniq)}
        num_unique = len(uniq)
        counts = [0] * num_unique
        for j in range(n):
            counts[pat_to_idx[sorted_values[j]]] += 1
        group_ends = []
        cum = 0
        for g in range(num_unique):
            cum += counts[g]
            group_ends.append(cum)
        # Step 1.4: decode patterns
        scatter_offsets = [0]
        scatter_rows = []
        for pat in uniq:
            rows = []
            for bit_pos in range(k):  # bit_pos 0 = LSB of integer
                if pat & (1 << bit_pos):
                    rows.append(k - 1 - bit_pos)  # row 0 = MSB = highest bit_pos
            scatter_rows.extend(rows)
            scatter_offsets.append(len(scatter_rows))
        perms.append(perm)
        group_ends_list.append(group_ends)
        scatter_offsets_list.append(scatter_offsets)
        scatter_rows_list.append(scatter_rows)
    return perms, group_ends_list, scatter_offsets_list, scatter_rows_list


def infer(v, perms, group_ends_list, scatter_offsets_list, scatter_rows_list, n, k):
    """Compute y = M · v from the arrays returned by preprocess."""
    num_blocks = n // k
    out = [0.0] * n
    for b in range(num_blocks):
        perm = perms[b]
        group_ends = group_ends_list[b]
        scatter_offsets = scatter_offsets_list[b]
        scatter_rows = scatter_rows_list[b]
        num_unique = len(group_ends)
        # Step 2.1: permute input
        v_perm = [v[perm[j]] for j in range(n)]
        # Step 2.2: aggregate
        aggregated = []
        start = 0
        for g in range(num_unique):
            end = group_ends[g]
            agg = sum(v_perm[start:end])
            aggregated.append(agg)
            start = end
        # Step 2.3: scatter
        for g in range(num_unique):
            s_begin = scatter_offsets[g]
            s_end = scatter_offsets[g + 1]
            for s in range(s_begin, s_end):
                out[b*k + scatter_rows[s]] += aggregated[g]
    return out

Example: n=4, k=2, matrix:
M = [[1, 0, 1, 0],
[0, 1, 1, 0],
[1, 1, 0, 0],
[0, 0, 1, 1]]
num_blocks = 2. Input vector v = [1, 2, 3, 4].
Block 0 (rows 0–1):
block = [[1, 0, 1, 0],
[0, 1, 1, 0]]
bit_weights = [2, 1]
col_values = [2*1+1*0, 2*0+1*1, 2*1+1*1, 2*0+1*0] = [2, 1, 3, 0]
perm = argsort([2,1,3,0]) = [3, 1, 0, 2] (indices sorted by value)
sorted_values = [0, 1, 2, 3] (all unique)
uniq = [0, 1, 2, 3]
group_ends = [1, 2, 3, 4]
Pattern 0 (binary 00): no bits set → scatter_rows = []
Pattern 1 (binary 01): bit 0 set → row k-1-0 = 1 → scatter_rows = [1]
Pattern 2 (binary 10): bit 1 set → row k-1-1 = 0 → scatter_rows = [0]
Pattern 3 (binary 11): bits 0,1 set → rows 1, 0 → scatter_rows = [1, 0]
scatter_offsets = [0, 0, 1, 2, 4]
scatter_rows = [1, 0, 1, 0]
Inference:
v_perm = v[[3,1,0,2]] = [4, 2, 1, 3]
aggregated = [sum(v_perm[0:1]), sum(v_perm[1:2]), sum(v_perm[2:3]), sum(v_perm[3:4])]
= [4, 2, 1, 3]
Scatter:
group 0 (agg=4): no rows → nothing
group 1 (agg=2): row 1 → out[1] += 2
group 2 (agg=1): row 0 → out[0] += 1
group 3 (agg=3): rows 1,0 → out[1] += 3, out[0] += 3
out[0:2] = [4, 5]
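Block 1 (rows 2–3): col_values = [2, 2, 1, 1], perm = [2, 3, 0, 1], uniq = [1, 2], group_ends = [2, 4]. Aggregation gives [v[2]+v[3], v[0]+v[1]] = [7, 3]. Pattern 1 scatters 7 to row 1 and pattern 2 scatters 3 to row 0, so out[2:4] = [3, 7]. In this block, two groups of two columns each replace four separate column updates.
Verification: out[0] = M[0,:]·v = 1·1+0·2+1·3+0·4 = 4 ✓, and the full result y = [4, 5, 3, 7] matches M · v.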
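The worked example can be checked end to end with a compact sketch. Note that this version groups columns with a dictionary keyed by pattern instead of the sort-based permutation described above. That is a swapped-in simplification which yields the same aggregates. The name `rsr_matvec` is an assumption for this example.

```python
def rsr_matvec(M, v, k):
    """Blocked aggregate-then-scatter product for a 0/1 matrix given as nested lists."""
    n_rows, n_cols = len(M), len(M[0])
    assert n_rows % k == 0
    y = [0] * n_rows
    for b in range(n_rows // k):
        # Aggregate: sum v[j] per distinct k-bit column pattern (row 0 = MSB).
        sums = {}
        for j in range(n_cols):
            pat = 0
            for i in range(k):
                pat = (pat << 1) | M[b * k + i][j]
            sums[pat] = sums.get(pat, 0) + v[j]
        # Scatter: add each aggregate to the rows where its pattern has a 1.
        for pat, agg in sums.items():
            for i in range(k):
                if pat & (1 << (k - 1 - i)):
                    y[b * k + i] += agg
    return y


M = [[1, 0, 1, 0],
     [0, 1, 1, 0],
     [1, 1, 0, 0],
     [0, 0, 1, 1]]
v = [1, 2, 3, 4]
y_rsr = rsr_matvec(M, v, k=2)
y_naive = [sum(M[r][j] * v[j] for j in range(4)) for r in range(4)]
print(y_rsr, y_naive)
```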
For block b, the standard computation of output row r is:
out[b·k + r] = sum over j of M[b·k + r, j] * v[j]
RSR groups columns by their pattern. For group g with pattern pat, let C_g be the set of column indices in that group, and let pat[r] denote the bit of pat that corresponds to block row r. Every column j ∈ C_g satisfies M[b·k + r, j] = pat[r]. Therefore:
sum_{j in C_g} M[b·k + r, j] * v[j]
= pat[r] * sum_{j in C_g} v[j]
= pat[r] * aggregated[g]
Summing over all groups reproduces the full dot product.
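The last step relies on the groups C_g partitioning the columns of block b, so every column index appears in exactly one inner sum. Written as a single chain, with p_g and a_g used here as shorthand for the pattern and aggregated sum of group g:

```latex
\begin{aligned}
\mathrm{out}[b k + r]
  &= \sum_{j} M[b k + r,\, j]\; v[j] \\
  &= \sum_{g} \sum_{j \in C_g} M[b k + r,\, j]\; v[j] \\
  &= \sum_{g} p_g[r] \sum_{j \in C_g} v[j] \\
  &= \sum_{g} p_g[r]\; a_g .
\end{aligned}
```

Since p_g[r] is 0 or 1, the final sum is exactly what the scatter phase computes: a_g is added to row r only for groups whose pattern has a 1 in that row.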
The algorithm above assumes M is square (n × n). In practice, weight matrices in neural networks are typically non-square: M is n_rows × n_cols, v has length n_cols, and the output y has length n_rows.
The generalization is straightforward — only the block dimensions change:
- Block decomposition: `M` is split into horizontal blocks of `k` rows each. Block `b` is `M[b·k : (b+1)·k, :]` with shape `(k, n_cols)` instead of `(k, n)`. The number of blocks is `num_blocks = n_rows / k`.
- Column encoding: Each of the `n_cols` columns is encoded as a `k`-bit integer (unchanged logic, just iterating over `n_cols` columns instead of `n`).
- Permutation arrays: Each block's permutation `perms[b]` has `n_cols` entries (one per column), not `n`.
- Input vector: `v` has length `n_cols`. The permuted vector `v_perm[b]` also has `n_cols` entries.
- Output vector: `y` has length `n_rows = num_blocks · k`. The scatter phase writes to rows within each `k`-sized block, concatenated to form the full output.
- Row padding: If `n_rows` is not divisible by `k`, pad `M` with zero rows to the next multiple of `k` and trim the output back to `n_rows` after inference.
Everything else (the aggregate-then-scatter pattern, the column sort, the group structure) remains identical. The inference kernel does not need n_rows itself; it only needs n_cols (as the permutation stride), k, and num_blocks.
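The row-padding rule can be sketched as follows. The helper name `pad_rows` is hypothetical, and a plain dot product stands in for RSR inference so the example stays short.

```python
def pad_rows(M, k):
    """Return (M_padded, n_rows): M with all-zero rows appended up to a multiple of k."""
    n_rows, n_cols = len(M), len(M[0])
    padded = -(-n_rows // k) * k  # ceil(n_rows / k) * k
    M_padded = [list(row) for row in M] + [[0] * n_cols for _ in range(padded - n_rows)]
    return M_padded, n_rows


M = [[1, 0, 1],
     [0, 1, 1],
     [1, 1, 0]]  # 3 rows, not divisible by k = 2
M_padded, n_rows = pad_rows(M, k=2)
v = [1, 2, 3]

# Stand-in for RSR inference on the padded matrix.
y_padded = [sum(a * x for a, x in zip(row, v)) for row in M_padded]
y = y_padded[:n_rows]  # trim the outputs of the padding rows
print(len(M_padded), y)
```

Zero rows never set a pattern bit, so they add no scatter entries and only the final trim is needed.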
Compute y = M · v where M is an n × n ternary matrix with entries in {-1, 0, +1}.
A ternary matrix M can be decomposed into two binary matrices:
M_pos[i, j] = 1 if M[i, j] = +1, else 0
M_neg[i, j] = 1 if M[i, j] = -1, else 0
These satisfy M = M_pos - M_neg, so:
y = M · v = M_pos · v − M_neg · v
Each of M_pos and M_neg is a binary {0, 1} matrix, so the standard binary RSR algorithm applies to each independently.
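A quick numerical check of this decomposition, using made-up data and hypothetical helper names `split_ternary` and `matvec`:

```python
def split_ternary(M):
    """Split a {-1, 0, +1} matrix into its positive and negative 0/1 masks."""
    M_pos = [[1 if x == 1 else 0 for x in row] for row in M]
    M_neg = [[1 if x == -1 else 0 for x in row] for row in M]
    return M_pos, M_neg


def matvec(A, v):
    """Plain dense matrix-vector product on nested lists."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


M = [[1, -1, 0],
     [0, 1, -1],
     [-1, 0, 1]]
v = [2, 3, 5]
M_pos, M_neg = split_ternary(M)
y_direct = matvec(M, v)
y_split = [p - q for p, q in zip(matvec(M_pos, v), matvec(M_neg, v))]
print(y_direct, y_split)
```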
Rather than maintaining two separate RSR data structures, the ternary extension uses a single 2k-bit column encoding per block:
For block b (rows b·k to (b+1)·k - 1), each column j is encoded as a 2k-bit integer:
ternary_code[j] = (pos_pattern << k) | neg_pattern
where:
- `pos_pattern`: k-bit integer encoding which rows have `M[i, j] = +1` (same MSB convention as binary RSR)
- `neg_pattern`: k-bit integer encoding which rows have `M[i, j] = -1`
This encoding has at most 3^k unique values per block (since each row position can be +1, -1, or 0). Columns are sorted by their ternary code, grouped by unique code, and then:
- Aggregate: sum `v[j]` for all columns `j` in the same group (identical to binary RSR)
- Signed Scatter: for each set bit in `pos_pattern`, add the aggregate to the output row; for each set bit in `neg_pattern`, subtract it
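To make the 2k-bit code concrete, the sketch below encodes one ternary column and decodes it back into signed scatter entries. Both function names are assumptions introduced for this example.

```python
def encode_ternary_column(col):
    """col: k entries in {-1, 0, +1}, row 0 first. Returns the 2k-bit code."""
    k = len(col)
    pos_pattern = neg_pattern = 0
    for x in col:  # row 0 ends up in the most significant bit
        pos_pattern = (pos_pattern << 1) | (1 if x == 1 else 0)
        neg_pattern = (neg_pattern << 1) | (1 if x == -1 else 0)
    return (pos_pattern << k) | neg_pattern


def decode_scatter_entries(code, k):
    """Return [(row, sign), ...] for the nonzero rows of a 2k-bit ternary code."""
    pos_pattern = code >> k
    neg_pattern = code & ((1 << k) - 1)
    entries = []
    for row in range(k):
        bit = 1 << (k - 1 - row)
        if pos_pattern & bit:
            entries.append((row, +1))
        elif neg_pattern & bit:
            entries.append((row, -1))
    return entries


code = encode_ternary_column([1, -1, 0])  # pos = 0b100, neg = 0b010
entries = decode_scatter_entries(code, k=3)
print(code, entries)
```

A row can be set in at most one of the two halves, which is why only 3^k of the 2^(2k) possible codes can occur.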
M_pos = (M == +1) # binary
M_neg = (M == -1) # binary
bit_weights = [2^(k-1), ..., 2^0]
pos_pattern = bit_weights @ M_pos[b*k:(b+1)*k, :]
neg_pattern = bit_weights @ M_neg[b*k:(b+1)*k, :]
ternary_code = (pos_pattern << k) | neg_pattern
perm = argsort(ternary_code) # or counting sort
# group_ends, scatter_rows, scatter_signs computed from unique codes

for each block b:
    v_perm = v[perm] # permute input
    for each group g:
        agg = sum(v_perm[group_start:group_end]) # aggregate
        for each (row, sign) in scatter_entries[g]:
            out[b*k + row] += sign * agg # signed scatter

| Name | Shape | Type | Description |
|---|---|---|---|
| `perms[b]` | `(n,)` | int | Column permutation sorting by ternary code |
| `group_ends[b]` | `(num_unique_b,)` | int | Cumulative group end indices |
| `scatter_offsets[b]` | `(num_unique_b + 1,)` | int | Offsets into scatter arrays per group |
| `scatter_rows[b]` | `(total_entries_b,)` | int | Output row indices (within block) |
| `scatter_signs[b]` | `(total_entries_b,)` | int | +1 or -1 per scatter entry |
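Putting the ternary pieces together, here is a runnable pure-Python sketch of the unified scheme. For brevity it recomputes the codes and permutation on every call instead of storing the arrays from the table, and it decodes signs directly from the code rather than from `scatter_rows` and `scatter_signs`. The function name `ternary_rsr_matvec` is an assumption.

```python
def ternary_rsr_matvec(M, v, k):
    """Blocked sort, aggregate and signed scatter for a {-1, 0, +1} matrix."""
    n_rows, n_cols = len(M), len(M[0])
    assert n_rows % k == 0
    mask = (1 << k) - 1
    y = [0] * n_rows
    for b in range(n_rows // k):
        # Encode every column of the block as a 2k-bit ternary code.
        codes = []
        for j in range(n_cols):
            pos = neg = 0
            for i in range(k):
                x = M[b * k + i][j]
                pos = (pos << 1) | (1 if x == 1 else 0)
                neg = (neg << 1) | (1 if x == -1 else 0)
            codes.append((pos << k) | neg)
        # Sort columns so equal codes are contiguous.
        perm = sorted(range(n_cols), key=lambda j: codes[j])
        start = 0
        while start < n_cols:
            code = codes[perm[start]]
            end, agg = start, 0
            while end < n_cols and codes[perm[end]] == code:
                agg += v[perm[end]]          # aggregate the group
                end += 1
            for row in range(k):             # signed scatter
                bit = 1 << (k - 1 - row)
                if (code >> k) & bit:
                    y[b * k + row] += agg
                elif (code & mask) & bit:
                    y[b * k + row] -= agg
            start = end
    return y


M = [[1, -1, 0, 1],
     [0, -1, 1, 0],
     [-1, -1, 0, -1],
     [1, 0, 0, 1]]
v = [1, 2, 3, 4]
y_rsr = ternary_rsr_matvec(M, v, k=2)
y_naive = [sum(a * x for a, x in zip(row, v)) for row in M]
print(y_rsr, y_naive)
```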
| Approach | Permutations | Aggregations | Scatter passes | Max unique patterns |
|---|---|---|---|---|
| Two binary RSRs | 2 per block | 2 per block | 2 per block | 2^k each |
| Unified ternary | 1 per block | 1 per block | 1 per block | 3^k |
The unified approach halves the number of permutation and aggregation passes at the cost of a larger pattern space. For typical k values (4–8), 3^k remains manageable (81–6561 patterns).
The non-square generalization described in the "Extension to Non-Square Matrices" section applies identically to the ternary case. For a ternary matrix M of shape (n_rows, n_cols), each block has n_cols columns encoded as 2k-bit integers, permutations have n_cols entries per block, and the output has n_rows elements. Row padding to a multiple of k is applied if needed.