Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ compute_scale <- function(mat, centering) {
.Call('_BiocSingular_compute_scale', PACKAGE = 'BiocSingular', mat, centering)
}

set_omp_threads <- function(nthreads) {
.Call('_BiocSingular_set_omp_threads', PACKAGE = 'BiocSingular', nthreads)
}

11 changes: 11 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,20 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// set_omp_threads
Rcpp::IntegerVector set_omp_threads(Rcpp::IntegerVector nthreads);
RcppExport SEXP _BiocSingular_set_omp_threads(SEXP nthreadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type nthreads(nthreadsSEXP);
rcpp_result_gen = Rcpp::wrap(set_omp_threads(nthreads));
return rcpp_result_gen;
END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_BiocSingular_compute_scale", (DL_FUNC) &_BiocSingular_compute_scale, 2},
{"_BiocSingular_set_omp_threads", (DL_FUNC) &_BiocSingular_set_omp_threads, 1},
{NULL, NULL, 0}
};

Expand Down
62 changes: 44 additions & 18 deletions src/compute_scale.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "Rcpp.h"
#include "beachmat/numeric_matrix.h"
#include "beachmat/integer_matrix.h"
#include "beachmat/utils/const_column.h"

#include <cmath>
#include <omp.h>

template<class M>
Rcpp::NumericVector compute_scale_internal(Rcpp::RObject mat, Rcpp::RObject centering) {
Expand All @@ -24,28 +25,43 @@ Rcpp::NumericVector compute_scale_internal(Rcpp::RObject mat, Rcpp::RObject cent
}

Rcpp::NumericVector output(ncols);
beachmat::const_column<M> col_holder(ptr.get());

for (size_t i=0; i<ncols; ++i) {
col_holder.fill(i);
auto n=col_holder.get_n();
auto vals=col_holder.get_values();
int nth=omp_get_max_threads();
Rcpp::NumericVector battery(nrows*nth);
Rprintf("%i\n", nth);

double& current=output[i];
for (size_t j=0; j<n; ++j, ++vals) {
double val=*vals;
if (do_center) {
val-=numeric_centers[i];
}
current+=val*val;
}
#pragma omp parallel
{
auto holder=battery.begin() + omp_get_thread_num() * nrows;
decltype(ptr) pptr=NULL;

if (do_center && col_holder.is_sparse()) { // adding the contribution of the zeroes.
current += (nrows - n) * (numeric_centers[i] * numeric_centers[i]);
#pragma omp critical
if (nth==1) {
pptr=std::move(ptr);
} else {
pptr=ptr->clone();
}

current/=nrows-1;
current=std::sqrt(current);
#pragma omp for schedule(static)
for (size_t i=0; i<ncols; ++i) {
#pragma omp critical
{
pptr->get_col(i, holder);
}

double& current=output[i];
auto vals=holder;
for (size_t j=0; j<nrows; ++j, ++vals) {
double val=*vals;
if (do_center) {
val-=numeric_centers[i];
}
current+=val*val;
}

current/=nrows-1;
current=std::sqrt(current);
}
}

return output;
Expand All @@ -62,3 +78,13 @@ Rcpp::NumericVector compute_scale(Rcpp::RObject mat, Rcpp::RObject centering) {
return compute_scale_internal<beachmat::numeric_matrix>(mat, centering);
}
}

// [[Rcpp::export(rng=false)]]
Rcpp::IntegerVector set_omp_threads(Rcpp::IntegerVector nthreads)
{
if (nthreads.size()!=1L) {
Rf_error("'nthreads' must be integer(1)");
}
omp_set_num_threads(nthreads[0]);
return nthreads;
}