Skip to content
This repository was archived by the owner on Nov 20, 2025. It is now read-only.

Commit b29fe6c

Browse files
Merge pull request #39 from alexmccreight/master
fix standardization/intercept bug
2 parents 037bf1d + e8dc31a commit b29fe6c

3 files changed

Lines changed: 49 additions & 24 deletions

File tree

R/sufficient_stats_methods.R

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ initialize_susie_model.ss <- function(data, params, var_y, ...) {
4545
if (params$unmappable_effects %in% c("inf", "ash")) {
4646

4747
# Initialize omega quantities for unmappable effects
48-
omega_res <- compute_omega_quantities(data, tau2 = 0, sigma2 = 1)
48+
omega_res <- compute_omega_quantities(data, tau2 = 0, sigma2 = var_y)
4949
model$omega_var <- omega_res$omega_var
5050
model$predictor_weights <- omega_res$diagXtOmegaX
5151
model$XtOmegay <- data$eigen_vectors %*% (data$VtXty / omega_res$omega_var)
@@ -291,10 +291,8 @@ neg_loglik.ss <- function(data, params, model, V_param, ser_stats, ...) {
291291
#' @keywords internal
292292
update_fitted_values.ss <- function(data, params, model, l) {
293293
if (params$unmappable_effects != "none") {
294-
model$XtXr <- compute_Xb(data$XtX, colSums(model$alpha * model$mu) + model$theta)
294+
model$XtXr <- as.vector(data$XtX %*% (colSums(model$alpha * model$mu) + model$theta))
295295
} else {
296-
# Fix: Use direct matrix multiplication to match original implementation
297-
# Original: s$XtXr = s$XtXr + XtX %*% (s$alpha[l,] * s$mu[l,])
298296
model$XtXr <- model$fitted_without_l + as.vector(data$XtX %*% (model$alpha[l, ] * model$mu[l, ]))
299297
}
300298
return(model)
@@ -344,7 +342,7 @@ update_variance_components.ss <- function(data, params, model, ...) {
344342
))
345343
}
346344

347-
# Remove the sparse effects
345+
# Remove the sparse effects to compute residuals for mr.ash
348346
b <- colSums(model$alpha * model$mu)
349347
residuals <- data$y - data$X %*% b
350348

R/susie_constructors.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ individual_data_constructor <- function(X, y, L = min(10, ncol(X)),
6767
}
6868
mean_y <- mean(y)
6969

70+
# Force required preprocessing for unmappable effects methods
71+
if (unmappable_effects != "none") {
72+
if (!intercept) {
73+
warning_message("Unmappable effects methods require centered data. Setting intercept=TRUE.")
74+
intercept <- TRUE
75+
}
76+
if (!standardize) {
77+
warning_message("Unmappable effects methods require scaled data. Setting standardize=TRUE.")
78+
standardize <- TRUE
79+
}
80+
}
81+
7082
# Handle null weights
7183
if (is.numeric(null_weight) && null_weight == 0) {
7284
null_weight <- NULL

R/susie_utils.R

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -596,8 +596,8 @@ add_null_effect <- function(model_init, V) {
596596
# and log Bayes factor calculations.
597597
#
598598
# Functions: compute_eigen_decomposition, add_eigen_decomposition,
599-
# compute_omega_quantities, compute_theta_blup, lbf_stabilization,
600-
# compute_posterior_weights, compute_lbf_gradient
599+
# compute_omega_quantities, scale_design_matrix, compute_theta_blup,
600+
# lbf_stabilization, compute_posterior_weights, compute_lbf_gradient
601601
# =============================================================================
602602

603603
# Compute eigenvalue decomposition for unmappable methods
@@ -617,19 +617,6 @@ compute_eigen_decomposition <- function(XtX, n) {
617617
# Add eigen decomposition to ss data objects for unmappable methods
618618
#' @keywords internal
619619
add_eigen_decomposition <- function(data, params, individual_data = NULL) {
620-
# Standardize y to unit variance for all unmappable effects methods
621-
y_scale_factor <- 1
622-
623-
if (params$unmappable_effects != "none") {
624-
var_y <- data$yty / (data$n - 1)
625-
if (abs(var_y - 1) > 1e-10) {
626-
sd_y <- sqrt(var_y)
627-
data$yty <- data$yty / var_y
628-
data$Xty <- data$Xty / sd_y
629-
y_scale_factor <- sd_y
630-
}
631-
}
632-
633620
# Compute eigen decomposition
634621
eigen_decomp <- compute_eigen_decomposition(data$XtX, data$n)
635622

@@ -638,19 +625,47 @@ add_eigen_decomposition <- function(data, params, individual_data = NULL) {
638625
data$eigen_values <- eigen_decomp$Dsq
639626
data$VtXty <- t(eigen_decomp$V) %*% data$Xty
640627

641-
# SuSiE.ash requires the X matrix and standardized y vector
642628
if (params$unmappable_effects == "ash") {
643629
if (is.null(individual_data)) {
644630
stop("Adaptive shrinkage (ash) requires individual-level data")
645631
}
646-
data$X <- individual_data$X
647-
data$y <- individual_data$y / y_scale_factor
648-
data$VtXt <- t(data$eigen_vectors) %*% t(individual_data$X)
632+
633+
X_scaled <- scale_design_matrix(
634+
individual_data$X,
635+
center = attr(individual_data$X, "scaled:center"),
636+
scale = attr(individual_data$X, "scaled:scale")
637+
)
638+
639+
data$X <- X_scaled
640+
data$y <- individual_data$y
641+
data$VtXt <- t(data$eigen_vectors) %*% t(X_scaled)
649642
}
650643

651644
return(data)
652645
}
653646

647+
#' Scale design matrix using centering and scaling parameters
648+
#'
649+
#' Applies column-wise centering and scaling to match the space used by
650+
#' compute_XtX() and compute_Xty() for unmappable effects methods.
651+
#'
652+
#' @param X Matrix to scale (n × p)
653+
#' @param center Vector of column means to subtract (length p), or NULL
654+
#' @param scale Vector of column SDs to divide by (length p), or NULL
655+
#'
656+
#' @return Scaled matrix with centered and scaled columns
657+
#'
658+
#' @keywords internal
659+
scale_design_matrix <- function(X, center = NULL, scale = NULL) {
660+
if (is.null(center)) center <- rep(0, ncol(X))
661+
if (is.null(scale)) scale <- rep(1, ncol(X))
662+
663+
X_centered <- sweep(X, 2, center, "-")
664+
X_scaled <- sweep(X_centered, 2, scale, "/")
665+
666+
return(X_scaled)
667+
}
668+
654669
# Compute Omega-weighted quantities for unmappable effects methods
655670
#' @keywords internal
656671
compute_omega_quantities <- function(data, tau2, sigma2) {

0 commit comments

Comments
 (0)