Skip to content

Commit

Permalink
use subset of responses to determines early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
RuilinLi committed Dec 28, 2020
1 parent 1f31cc7 commit d56da36
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions R/basil.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' @importFrom magrittr %>%
#' @export
basil_base <- function(genotype.pfile, phe.file, responsid, covs, nlambda, lambda.min.ratio,
alpha, p.factor, configs, num_lambda_per_iter, num_to_add, max_num_to_add, fit.fun)
alpha, p.factor, configs, num_lambda_per_iter, num_to_add, max_num_to_add, fit.fun, stop_res)
{
time.start <- Sys.time()
responsid <- as.character(responsid)
Expand Down Expand Up @@ -61,6 +61,15 @@ basil_base <- function(genotype.pfile, phe.file, responsid, covs, nlambda, lambd
responsid <- responsid[!(responsid %in% id_to_remove)]
names(status) <- responsid
names(responses) <- responsid

if(!is.null(stop_res)){
stop_res = as.character(stop_res)
if(!all(stop_res %in% responsid)){
stop("stop_res must be NULL or one of the responses")
}
} else {
stop_res = responsid
}

K <- length(responsid) # Number of responses, this might change
if (is.null(alpha))
Expand Down Expand Up @@ -219,8 +228,8 @@ basil_base <- function(genotype.pfile, phe.file, responsid, covs, nlambda, lambd

# Use validation C-index to determine early stop
max_cindex <- Cval[, 1]
early_stop <- rep(FALSE, K)
names(early_stop) <- responsid
early_stop <- rep(FALSE, length(stop_res))
names(early_stop) <- stop_res
current_B <- result[[1]]
rownames(current_B) <- covs
colnames(current_B) <- responsid
Expand Down Expand Up @@ -392,18 +401,17 @@ basil_base <- function(genotype.pfile, phe.file, responsid, covs, nlambda, lambd
snpnetLoggerTimeDiff(sprintf("End metric evaluations for basil iteration %d.",
iter), time.basilmetric.start, indent = 3)
# Save temp result to files
save_list <- list(Ctrain = Ctrain, Cval = Cval, Ctest=Ctest, beta = out)
save_list <- list(Ctrain = Ctrain, Cval = Cval, Ctest=Ctest, beta = out, duration=Sys.time() - time.start)
save(save_list, file = file.path(configs[["save.dir"]], paste0("saveresult",
iter, ".RData")))

last_Cval_this_iter <- Cval[, (max_valid_index + local_valid)]
# max_Cval_this_iter = apply(Cval[,(max_valid_index +
# 1):(max_valid_index+local_valid), drop=F], 1, max) early_stop = early_stop |
# (last_Cval_this_iter < max_cindex)
early_stop <- early_stop | (last_Cval_this_iter < max_cindex)

# Don't stop too early
if (all(early_stop) && length(ever.active) > 3000)
early_stop <- early_stop | (last_Cval_this_iter[stop_res] < max_cindex[stop_res]- 0.005)

if (all(early_stop))
{
snpnetLoggerTimeDiff("Early stop for all responses reached.", time.start,
indent = 3)
Expand Down Expand Up @@ -446,9 +454,9 @@ basil_base <- function(genotype.pfile, phe.file, responsid, covs, nlambda, lambd
#' @export
basil <- function(genotype.pfile, phe.file, responsid, covs = NULL, nlambda = 100,
lambda.min.ratio = 0.01, alpha = NULL, p.factor = NULL, configs = NULL, num_lambda_per_iter = 10,
num_to_add = 1500, max_num_to_add = 6000)
num_to_add = 1500, max_num_to_add = 6000,stop_res=NULL)
{
basil_base(genotype.pfile, phe.file, responsid, covs, nlambda, lambda.min.ratio,
alpha, p.factor, configs, num_lambda_per_iter, num_to_add, max_num_to_add,
solve_aligned)
solve_aligned,stop_res)
}

0 comments on commit d56da36

Please sign in to comment.