Skip to content
6 changes: 5 additions & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,8 @@
^.*[.]Rproj.*$
^\.Rproj\.user$
^.Rhistory$
[~]
[~]
^R/cv.cureitlasso.R
^R/simulasso.R
^R/coxsplit.R
^R/run-simulations.R
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
inst/doc
inst/model-fits
inst/cure-model-fits
inst/cureit-simulation-results/*
inst/archive
docs/
revdep/
Expand Down
65 changes: 63 additions & 2 deletions R/cureitlasso.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,47 @@

# mu is penlanlity for cure
# lamda is penality for cox
# touch is posterior probabilty of being cured fr uncured (maybe don't need) (Entropy)
# NOTE maybe change name of alphs
# calculated linear peredictor of that iteration


#' Lasso for Cure Models
#'
#' @param t Survival times
#' @param d Event indicator
#' @param x data n x p matrix or data.frame
#' @param minmu.ratio minimum penalty value in the logistic model
#' @param minlambda.ratio minimum penalty value in the cox model
#' @param adaptive
#' @param length.grid
#' @param mus penalty for cure
#' @param lambdas is penalty for cox
#' @param tol posterior probability of being cured for uncured (maybe don't need) (Entropy)
#' @param maxiter
#' @param progress verbose iteration counter
#'
#' @return
#' - `all_fits`
#' - `t`,
#' - `tj` - unique event times sorted,
#' - `d`- event indicator,
#' - `x`- data n x p matrix or data.frame,
#' - `alpha0` fitcure$a0[length(fitcure$lambda)],
#' - `alpha` =alpha,
#' - `beta` =beta,
#' - `haz`=haz,
#' - `cumhaz`=cumhaz,
#' - `tau`=as.vector(tau),
#' - `predsurv`=as.vector(predsurv_iter),
#' - `predcure`=as.vector(predcure_iter),
#' - `fitcure`=fitcure,
#' - `fitcox` =fitcox
#' - `mus` - mus that were used as penalties for cure models in cross validation
#' - `lambda` - lambdas that were used as penalities in cix model cross validation
#' @export
#'
#' @examples
cureitlasso <- function(t,
d,
x,
Expand All @@ -15,6 +59,7 @@ cureitlasso <- function(t,
require(glmnet)
require(survival)

# initialize initial weights
tau <- matrix(0.5,nrow=length(t),ncol=2) # First column: cured; second column: uncured
tau[d==1,1] = 0
tau[d==1,2] = 1
Expand All @@ -24,6 +69,7 @@ cureitlasso <- function(t,

# CV: Be careful for fold split! Two replicates from the same subject should always be assigned to the same fold

#this is for later integreation with tuning
penalty.factor.cure=rep(1,ncol(x))
penalty.factor.cox=rep(1,ncol(x))

Expand Down Expand Up @@ -51,6 +97,8 @@ cureitlasso <- function(t,
#
# }

# So this section is finding the set of optimal hyperparemters to test ----
# tau is uncure probability
fitcure0 <- glmnet(x=do.call("rbind", rep(list(x), 2)),
y=lab_class,
family="binomial",
Expand All @@ -72,8 +120,10 @@ cureitlasso <- function(t,
lambdas <- fitcox0$lambda[idx_lambdas]
}


musmax <- fitcure0$lambda[1]
lambdasmax <- fitcox0$lambda[1]
# -----------------

predcure <- predict(fitcure0,newx=x,s=mus,type="response") #Prob of uncured
predsurvexp <- predict(fitcox0,newx=x,s=lambdas,type="response") # exp(betaX)
Expand Down Expand Up @@ -114,6 +164,7 @@ cureitlasso <- function(t,

fit[[i]] <- list()


for (j in 1:length(lambdas)){

tau <- d + (1-d) * (predcure[,i] * predsurv[,j])/( (1 - predcure[,i]) + predcure[,i] * predsurv[,j] )
Expand Down Expand Up @@ -319,14 +370,24 @@ cureitlasso <- function(t,
# num_alpha[i,j] <- sum(alpha!=0)
# num_beta[i,j] <- sum(beta!=0)

# names(fit[i][j]) <- paste0("mu_fit_", i, "_lambda_fit_", "j")

}


names(fit[[i]]) <- paste0("lambda_fit_", 1:length(lambdas))
}

return(list(fit=fit,
names(fit) <- paste0("mu_fit_", 1:length(mus))

return(list(fit = fit,
# num_alpha=num_alpha,
# num_beta=num_beta,
mus=mus,
lambdas=lambdas))

}
}

# tj - is unique event times sorted
# haz

22 changes: 17 additions & 5 deletions R/cv.cureitlasso.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Returns:
# fit -
# arg: length.grid -

cv.cureitlasso <- function(t,
d,
x,
minmu.ratio=0.05, # minimum penalty value in the logistic model
minlambda.ratio=0.05, # minimum penalty value in the cox model
adaptive=FALSE,
length.grid=10,
length.grid=10, # tells how many hyperparameters to test. default is 10 for each
nfolds=5,
tol=1e-2,
maxiter=100,
Expand All @@ -22,13 +26,15 @@ cv.cureitlasso <- function(t,
require(survcomp)
require(pracma)

# order data by times
order_idx <- order(t)
t <- t[order_idx]
d <- d[order_idx]
x <- x[order_idx,]

if (progress) print("Fitting by EM algorithm ...")

# main model with no cross validation
fit <- cureitlasso(t,d,x,
minmu.ratio,
minlambda.ratio,
Expand All @@ -38,14 +44,16 @@ cv.cureitlasso <- function(t,
lambdas=NULL,
tol,
maxiter,
progress)
progress = TRUE)

if (progress) print("Running cross validations ...")

if (is.null(seed)) seed <- as.numeric(Sys.Date())

# Fold split
foldid <- coxsplit(as.matrix(Surv(t,d)), nfolds)

# grid/array of values
cv_brier <- array(NA,dim=c(length.grid,length.grid,nfolds))

# Run CVs
Expand All @@ -57,6 +65,7 @@ cv.cureitlasso <- function(t,

if (progress) print(i)


cv.fit[[i]] <- cureitlasso(t[foldid != i],
d[foldid != i],
x[foldid != i,],
Expand Down Expand Up @@ -114,6 +123,7 @@ cv.cureitlasso <- function(t,

}

# calcualte trapazoidal area for brier score
cv_brier[j,k,i] <- trapz(tbrier,brier)


Expand All @@ -130,7 +140,7 @@ cv.cureitlasso <- function(t,

cv.fit <- foreach(i = 1:nfolds) %dopar% {

source("~/Projects/Whiting-Qin-cureit/cureit/R/cureitlasso.R")
source(here::here("R/cureitlasso.R"))

cureitlasso(t[foldid != i],
d[foldid != i],
Expand Down Expand Up @@ -206,6 +216,7 @@ cv.cureitlasso <- function(t,

}

# summarize - take average across all folds
cv_brier_mean <- apply(cv_brier,c(1,2),function(x) mean(x))
cv_brier_se <- apply(cv_brier,c(1,2),function(x) sd(x)/sqrt(nfolds))
# pheatmap::pheatmap(cv_brier_mean,cluster_cols = F,cluster_rows = F)
Expand All @@ -231,11 +242,12 @@ cv.cureitlasso <- function(t,
cv_brier_mean = cv_brier_mean,
cv_brier_se = cv_brier_se,
index = list(min=idxmin,
`1se`= idx1se),
`1se`= idx1se),
foldid = foldid,
selected = list(min=selectedmin,
`1se`=selected1se)
)
)

}

}
Loading