Skip to content

Commit 8647774

Browse files
committed
Store weights column positions to optimize matrix multiplication in masked descale
1 parent 4f30505 commit 8647774

2 files changed

Lines changed: 30 additions & 7 deletions

File tree

include/descale.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ typedef struct DescaleCore
9494
double *multiplied_weights;
9595
int *weights_left_idx;
9696
int *weights_right_idx;
97+
int *weights_top_idx;
98+
int *weights_bot_idx;
9799
int weights_columns;
98100
} DescaleCore;
99101

src/descale.c

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,8 @@ static inline int check_imask(unsigned char value) {
536536
return value >= 128;
537537
}
538538

539-
static void process_plane_masked(int dst_dim, int src_dim, int vector_count, enum DescaleDir dir, int bandwidth, int * restrict weights_left_idx, int * restrict weights_right_idx,
539+
static void process_plane_masked(int dst_dim, int src_dim, int vector_count, enum DescaleDir dir, int bandwidth,
540+
int * restrict weights_left_idx, int * restrict weights_right_idx, int * restrict weights_top_idx, int * restrict weights_bot_idx,
540541
int weights_columns, float * restrict weights, double * restrict multiplied_weights,
541542
int src_stride, int imask_stride, int dst_stride, const float * restrict srcp, const unsigned char * restrict imaskp, float * restrict dstp)
542543
{
@@ -583,11 +584,12 @@ static void process_plane_masked(int dst_dim, int src_dim, int vector_count, enu
583584
for (int j = imask_start; j < src_dim; j++) {
584585
if (!check_imask(imaskp[i * imuli + j * jmuli]))
585586
continue;
586-
for (int r = 0; r < dst_dim; r++) {
587-
if (j < weights_left_idx[r] || j >= weights_right_idx[r]) continue;
588-
for (int s = r; s < dst_dim; s++) {
589-
if (j < weights_left_idx[s] || j >= weights_right_idx[s]) continue;
590-
modified_ldlt[r * bandwidth + s - r] -= weights[r * weights_columns + j - weights_left_idx[r]] * weights[s * weights_columns + j - weights_left_idx[s]];
587+
int top = weights_top_idx[j];
588+
int bot = weights_bot_idx[j];
589+
for (int r = top; r < bot; r++) {
590+
double wr = weights[r * weights_columns + j - weights_left_idx[r]];
591+
for (int s = r; s < bot; s++) {
592+
modified_ldlt[r * bandwidth + s - r] -= wr * weights[s * weights_columns + j - weights_left_idx[s]];
591593
}
592594
}
593595
}
@@ -650,7 +652,8 @@ static void descale_process_vectors_c(struct DescaleCore *core, enum DescaleDir
650652
{
651653

652654
if (imaskp) {
653-
process_plane_masked(core->dst_dim, core->src_dim, vector_count, dir, core->bandwidth, core->weights_left_idx, core->weights_right_idx,
655+
process_plane_masked(core->dst_dim, core->src_dim, vector_count, dir, core->bandwidth,
656+
core->weights_left_idx, core->weights_right_idx, core->weights_top_idx, core->weights_bot_idx,
654657
core->weights_columns, core->weights, core->multiplied_weights, src_stride, imask_stride, dst_stride, srcp, imaskp, dstp);
655658
} else if (dir == DESCALE_DIR_HORIZONTAL) {
656659
if (core->bandwidth == 3)
@@ -716,6 +719,8 @@ static struct DescaleCore *create_core(int src_dim, int dst_dim, struct DescaleP
716719

717720
core.weights_left_idx = calloc(ceil_n(dst_dim, 8), sizeof (int));
718721
core.weights_right_idx = calloc(ceil_n(dst_dim, 8), sizeof (int));
722+
core.weights_top_idx = calloc(ceil_n(src_dim, 8), sizeof (int));
723+
core.weights_bot_idx = calloc(ceil_n(src_dim, 8), sizeof (int));
719724
for (int i = 0; i < dst_dim; i++) {
720725
for (int j = 0; j < src_dim; j++) {
721726
if (transposed_weights[i * src_dim + j] != 0.0) {
@@ -730,6 +735,20 @@ static struct DescaleCore *create_core(int src_dim, int dst_dim, struct DescaleP
730735
}
731736
}
732737
}
738+
for (int i = 0; i < src_dim; i++) {
739+
for (int j = 0; j < dst_dim; j++) {
740+
if (transposed_weights[j * src_dim + i] != 0.0) {
741+
core.weights_top_idx[i] = j;
742+
break;
743+
}
744+
}
745+
for (int j = dst_dim - 1; j >= 0; j--) {
746+
if (transposed_weights[j * src_dim + i] != 0.0) {
747+
core.weights_bot_idx[i] = j + 1;
748+
break;
749+
}
750+
}
751+
}
733752

734753
multiply_sparse_matrices(dst_dim, src_dim, core.weights_left_idx, core.weights_right_idx, transposed_weights, weights, &multiplied_weights);
735754

@@ -782,6 +801,8 @@ static void free_core(struct DescaleCore *core)
782801
free(core->weights);
783802
free(core->weights_left_idx);
784803
free(core->weights_right_idx);
804+
free(core->weights_top_idx);
805+
free(core->weights_bot_idx);
785806
free(core->multiplied_weights);
786807
free(core->diagonal);
787808
for (int i = 0; core->upper && i < core->bandwidth / 2; i++) {

0 commit comments

Comments
 (0)