Skip to content

Commit 7042b71

Browse files
committed
Store weights column positions to optimize matrix multiplication in masked descale
1 parent 874228f commit 7042b71

2 files changed

Lines changed: 30 additions & 7 deletions

File tree

include/descale.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ typedef struct DescaleCore
9393
double *multiplied_weights;
9494
int *weights_left_idx;
9595
int *weights_right_idx;
96+
int *weights_top_idx;
97+
int *weights_bot_idx;
9698
int weights_columns;
9799
} DescaleCore;
98100

src/descale.c

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,8 @@ static inline int check_imask(unsigned char value) {
538538
return value >= 128;
539539
}
540540

541-
static void process_plane_masked(int dst_dim, int src_dim, int vector_count, enum DescaleDir dir, int bandwidth, int * restrict weights_left_idx, int * restrict weights_right_idx,
541+
static void process_plane_masked(int dst_dim, int src_dim, int vector_count, enum DescaleDir dir, int bandwidth,
542+
int * restrict weights_left_idx, int * restrict weights_right_idx, int * restrict weights_top_idx, int * restrict weights_bot_idx,
542543
int weights_columns, float * restrict weights, double * restrict multiplied_weights,
543544
int src_stride, int imask_stride, int dst_stride, const float * restrict srcp, const unsigned char * restrict imaskp, float * restrict dstp)
544545
{
@@ -585,11 +586,12 @@ static void process_plane_masked(int dst_dim, int src_dim, int vector_count, enu
585586
for (int j = imask_start; j < src_dim; j++) {
586587
if (!check_imask(imaskp[i * imuli + j * jmuli]))
587588
continue;
588-
for (int r = 0; r < dst_dim; r++) {
589-
if (j < weights_left_idx[r] || j >= weights_right_idx[r]) continue;
590-
for (int s = r; s < dst_dim; s++) {
591-
if (j < weights_left_idx[s] || j >= weights_right_idx[s]) continue;
592-
modified_ldlt[r * bandwidth + s - r] -= weights[r * weights_columns + j - weights_left_idx[r]] * weights[s * weights_columns + j - weights_left_idx[s]];
589+
int top = weights_top_idx[j];
590+
int bot = weights_bot_idx[j];
591+
for (int r = top; r < bot; r++) {
592+
double wr = weights[r * weights_columns + j - weights_left_idx[r]];
593+
for (int s = r; s < bot; s++) {
594+
modified_ldlt[r * bandwidth + s - r] -= wr * weights[s * weights_columns + j - weights_left_idx[s]];
593595
}
594596
}
595597
}
@@ -652,7 +654,8 @@ static void descale_process_vectors_c(struct DescaleCore *core, enum DescaleDir
652654
{
653655

654656
if (imaskp) {
655-
process_plane_masked(core->dst_dim, core->src_dim, vector_count, dir, core->bandwidth, core->weights_left_idx, core->weights_right_idx,
657+
process_plane_masked(core->dst_dim, core->src_dim, vector_count, dir, core->bandwidth,
658+
core->weights_left_idx, core->weights_right_idx, core->weights_top_idx, core->weights_bot_idx,
656659
core->weights_columns, core->weights, core->multiplied_weights, src_stride, imask_stride, dst_stride, srcp, imaskp, dstp);
657660
} else if (dir == DESCALE_DIR_HORIZONTAL) {
658661
if (core->bandwidth == 3)
@@ -718,6 +721,8 @@ static struct DescaleCore *create_core(int src_dim, int dst_dim, struct DescaleP
718721

719722
core.weights_left_idx = calloc(ceil_n(dst_dim, 8), sizeof (int));
720723
core.weights_right_idx = calloc(ceil_n(dst_dim, 8), sizeof (int));
724+
core.weights_top_idx = calloc(ceil_n(src_dim, 8), sizeof (int));
725+
core.weights_bot_idx = calloc(ceil_n(src_dim, 8), sizeof (int));
721726
for (int i = 0; i < dst_dim; i++) {
722727
for (int j = 0; j < src_dim; j++) {
723728
if (transposed_weights[i * src_dim + j] != 0.0) {
@@ -732,6 +737,20 @@ static struct DescaleCore *create_core(int src_dim, int dst_dim, struct DescaleP
732737
}
733738
}
734739
}
740+
for (int i = 0; i < src_dim; i++) {
741+
for (int j = 0; j < dst_dim; j++) {
742+
if (transposed_weights[j * src_dim + i] != 0.0) {
743+
core.weights_top_idx[i] = j;
744+
break;
745+
}
746+
}
747+
for (int j = dst_dim - 1; j >= 0; j--) {
748+
if (transposed_weights[j * src_dim + i] != 0.0) {
749+
core.weights_bot_idx[i] = j + 1;
750+
break;
751+
}
752+
}
753+
}
735754

736755
multiply_sparse_matrices(dst_dim, src_dim, core.weights_left_idx, core.weights_right_idx, transposed_weights, weights, &multiplied_weights);
737756

@@ -784,6 +803,8 @@ static void free_core(struct DescaleCore *core)
784803
free(core->weights);
785804
free(core->weights_left_idx);
786805
free(core->weights_right_idx);
806+
free(core->weights_top_idx);
807+
free(core->weights_bot_idx);
787808
free(core->multiplied_weights);
788809
free(core->diagonal);
789810
for (int i = 0; core->upper && i < core->bandwidth / 2; i++) {

0 commit comments

Comments
 (0)