Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions R/GetAlgoParams.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' @param init_center A scalar or n_params-dimensional numeric vector, determines the mean of the Gaussian initialization distribution. The default value is 0.
#' @param n_cores_use An integer specifying the number of cores used when using parallelization. The default value is 1.
#' @param step_size A positive scalar, jump size or "F" in the DE crossover step notation. The default value is 2.38/sqrt(2*n_params).
#' @param jitter_size A positive scalar that determines the jitter (noise) size. Noise is added during adaption step from Uniform(-jitter_size,jitter_size) distribution. 1e-6 is the default value. Set to 0 to turn off jitter.
#' @param jitter_size A non-negative scalar that determines the jitter (noise) size. Noise is added during adaption step from Uniform(-jitter_size,jitter_size) distribution. 1e-6 is the default value. Set to 0 to turn off jitter.
#' @param parallel_type A string specifying parallelization type. 'none','FORK', or 'PSOCK' are valid values. 'none' is default value. 'FORK' does not work with Windows OS.
#' @param recovery_path A character scalar giving a file path where partial results are written via \code{saveRDS} periodically. Allows recovery of progress if the run is interrupted or crashes. Load with \code{readRDS(recovery_path)}. The saved object matches the structure of the return value of \code{optim_SQGDE} but excludes trace arrays. Set to \code{NULL} (default) to disable.
#' @param recovery_freq A positive integer controlling how often the recovery file is written. The file is saved every \code{recovery_freq} iterations. Default is 1 (every iteration). Ignored when \code{recovery_path} is \code{NULL}.
Expand Down Expand Up @@ -54,11 +54,12 @@ GetAlgoParams = function(n_params,
trace_print_freq = 100){
# n_params
### catch errors
if(length(n_params) > 1 || !is.finite(n_params)){
stop('ERROR: n_params must be a positive finite integer scalar')
}
n_params = as.integer(n_params)
if(any(!is.finite(n_params))){
stop('ERROR: n_params is not finite')
} else if( n_params<1 | length(n_params)>1){
stop('ERROR: n_params must be a postitive integer scalar')
if(n_params < 1){
stop('ERROR: n_params must be a positive finite integer scalar')
}

# n_particles
Expand All @@ -67,11 +68,12 @@ GetAlgoParams = function(n_params,
n_particles = max(3*n_params,4)
}
### catch errors
if(length(n_particles) > 1 || !is.finite(n_particles)){
stop('ERROR: n_particles must be a positive finite integer scalar, and at least 4')
}
n_particles = as.integer(n_particles)
if(any(!is.finite(n_particles))){
stop('ERROR: n_particles is not finite')
} else if( n_particles<4 | length(n_particles)>1){
stop('ERROR: n_particles must be a postitive integer scalar, and atleast 4')
if(n_particles < 4){
stop('ERROR: n_particles must be a positive finite integer scalar, and at least 4')
}

# n_iter
Expand All @@ -80,11 +82,12 @@ GetAlgoParams = function(n_params,
n_iter = 1000
}
### catch errors
if(length(n_iter) > 1 || !is.finite(n_iter)){
stop('ERROR: n_iter must be a positive finite integer scalar, and at least 4')
}
n_iter = as.integer(n_iter)
if(any(!is.finite(n_iter))){
stop('ERROR: n_iter is not finite')
} else if( n_iter<4 | length(n_iter)>1){
stop('ERROR: n_iter must be a postitive integer scalar, and atleast 4')
if(n_iter < 4){
stop('ERROR: n_iter must be a positive finite integer scalar, and at least 4')
}

# init_sd
Expand Down Expand Up @@ -172,8 +175,8 @@ GetAlgoParams = function(n_params,
### catch any errors
if(any(!is.finite(jitter_size))){
stop('ERROR: jitter_size is not finite')
} else if(any(jitter_size<= 0 | is.complex(jitter_size))){
stop('ERROR: jitter_size must be positive and real-valued')
} else if(any(jitter_size < 0 | is.complex(jitter_size))){
stop('ERROR: jitter_size must be non-negative and real-valued')
} else if(!(length(jitter_size) == 1)){
stop('ERROR: jitter_size vector length must be 1 ')
}
Expand Down
2 changes: 1 addition & 1 deletion man/GetAlgoParams.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

298 changes: 298 additions & 0 deletions tests/testthat/test-GetAlgoParams.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
library(graDiEnt)

ALL_NAMES <- c("n_params", "n_particles", "n_iter", "init_sd", "init_center",
"lower", "upper", "bounds_type", "n_cores_use", "step_size",
"crossover_rate", "jitter_size", "parallel_type", "recovery_path",
"recovery_freq", "thin", "purify", "n_iters_per_particle",
"return_trace", "n_diff", "adapt_scheme", "give_up_init",
"stop_tol", "stop_check", "converge_crit", "trace_print_freq")

# ── output structure ──────────────────────────────────────────────────────────

test_that("GetAlgoParams returns list with all expected names", {
cp <- GetAlgoParams(n_params = 3)
expect_named(cp, ALL_NAMES, ignore.order = TRUE)
})

# ── default values ────────────────────────────────────────────────────────────

test_that("n_particles default is max(3*n_params, 4)", {
expect_equal(GetAlgoParams(n_params = 5)$n_particles, 15L)
expect_equal(GetAlgoParams(n_params = 1)$n_particles, 4L) # max(3,4)=4
})

test_that("n_iter default is 1000", {
expect_equal(GetAlgoParams(n_params = 2)$n_iter, 1000L)
})

test_that("step_size default is 2.38/sqrt(2*n_params)", {
cp <- GetAlgoParams(n_params = 4)
expect_equal(cp$step_size, 2.38 / sqrt(2 * 4))
})

test_that("adapt_scheme default is 'rand'", {
expect_equal(GetAlgoParams(n_params = 2)$adapt_scheme, "rand")
})

test_that("bounds_type default is 'reflect'", {
expect_equal(GetAlgoParams(n_params = 2)$bounds_type, "reflect")
})

test_that("lower/upper default to vectors of -Inf/Inf length n_params", {
cp <- GetAlgoParams(n_params = 3)
expect_equal(cp$lower, rep(-Inf, 3))
expect_equal(cp$upper, rep(Inf, 3))
})

test_that("n_iters_per_particle = floor(n_iter / thin)", {
cp <- GetAlgoParams(n_params = 2, n_iter = 100, thin = 3)
expect_equal(cp$n_iters_per_particle, floor(100 / 3))
})

test_that("converge_crit default is 'stdev'", {
expect_equal(GetAlgoParams(n_params = 2)$converge_crit, "stdev")
})

test_that("jitter_size default is 1e-6", {
expect_equal(GetAlgoParams(n_params = 2)$jitter_size, 1e-6)
})

test_that("crossover_rate default is 1", {
expect_equal(GetAlgoParams(n_params = 2)$crossover_rate, 1)
})

test_that("purify default is Inf", {
expect_equal(GetAlgoParams(n_params = 2)$purify, Inf)
})

test_that("recovery_path default is NULL", {
expect_null(GetAlgoParams(n_params = 2)$recovery_path)
})

test_that("recovery_freq default is 1", {
expect_equal(GetAlgoParams(n_params = 2)$recovery_freq, 1L)
})

# ── n_params validation ───────────────────────────────────────────────────────

test_that("n_params is coerced to integer", {
cp <- GetAlgoParams(n_params = 3.0)
expect_identical(cp$n_params, 3L)
})

test_that("n_params < 1 errors", {
expect_error(GetAlgoParams(n_params = 0), "n_params")
})

test_that("n_params non-finite errors", {
expect_error(GetAlgoParams(n_params = Inf), "n_params")
})

test_that("n_params length > 1 errors", {
expect_error(GetAlgoParams(n_params = c(2, 3)), "scalar")
})

# ── n_particles validation ────────────────────────────────────────────────────

test_that("n_particles < 4 errors", {
expect_error(GetAlgoParams(n_params = 2, n_particles = 3), "n_particles")
})

test_that("n_particles non-finite errors", {
expect_error(GetAlgoParams(n_params = 2, n_particles = Inf), "n_particles")
})

test_that("n_particles coerced to integer", {
cp <- GetAlgoParams(n_params = 2, n_particles = 8.0)
expect_identical(cp$n_particles, 8L)
})

# ── n_iter validation ─────────────────────────────────────────────────────────

test_that("n_iter < 4 errors", {
expect_error(GetAlgoParams(n_params = 2, n_iter = 3), "n_iter")
})

test_that("n_iter coerced to integer", {
expect_identical(GetAlgoParams(n_params = 2, n_iter = 100.0)$n_iter, 100L)
})

# ── n_diff validation ─────────────────────────────────────────────────────────

test_that("n_diff > n_particles/2 errors", {
expect_error(
GetAlgoParams(n_params = 2, n_particles = 6, n_diff = 4),
"n_diff"
)
})

test_that("n_diff < 1 errors", {
expect_error(GetAlgoParams(n_params = 2, n_diff = 0), "n_diff")
})

test_that("n_diff coerced to integer", {
expect_identical(GetAlgoParams(n_params = 2, n_diff = 2.0)$n_diff, 2L)
})

# ── init_sd validation ────────────────────────────────────────────────────────

test_that("init_sd <= 0 errors", {
expect_error(GetAlgoParams(n_params = 2, init_sd = 0), "init_sd")
expect_error(GetAlgoParams(n_params = 2, init_sd = -1), "init_sd")
})

test_that("init_sd non-finite errors", {
expect_error(GetAlgoParams(n_params = 2, init_sd = Inf), "init_sd")
})

test_that("init_sd wrong length errors", {
expect_error(GetAlgoParams(n_params = 2, init_sd = c(0.1, 0.2, 0.3)), "init_sd")
})

test_that("init_sd length n_params accepted", {
expect_no_error(GetAlgoParams(n_params = 3, init_sd = c(0.1, 0.2, 0.3)))
})

# ── init_center validation ────────────────────────────────────────────────────

test_that("init_center non-finite errors", {
expect_error(GetAlgoParams(n_params = 2, init_center = Inf), "init_center")
})

test_that("init_center wrong length errors", {
expect_error(GetAlgoParams(n_params = 2, init_center = c(0, 0, 0)), "init_center")
})

test_that("init_center length n_params accepted", {
expect_no_error(GetAlgoParams(n_params = 3, init_center = c(1, 2, 3)))
})

# ── step_size validation ──────────────────────────────────────────────────────

test_that("step_size <= 0 errors", {
expect_error(GetAlgoParams(n_params = 2, step_size = 0), "step_size")
expect_error(GetAlgoParams(n_params = 2, step_size = -1), "step_size")
})

test_that("step_size non-finite errors", {
expect_error(GetAlgoParams(n_params = 2, step_size = Inf), "step_size")
})

test_that("step_size length > 1 errors", {
expect_error(GetAlgoParams(n_params = 2, step_size = c(0.5, 0.5)), "step_size")
})

# ── jitter_size validation ────────────────────────────────────────────────────

test_that("jitter_size = 0 accepted (turns off jitter)", {
expect_no_error(GetAlgoParams(n_params = 2, jitter_size = 0))
})

test_that("jitter_size < 0 errors", {
expect_error(GetAlgoParams(n_params = 2, jitter_size = -1e-6), "jitter_size")
})

test_that("jitter_size non-finite errors", {
expect_error(GetAlgoParams(n_params = 2, jitter_size = Inf), "jitter_size")
})

test_that("jitter_size length > 1 errors", {
expect_error(GetAlgoParams(n_params = 2, jitter_size = c(1e-6, 1e-6)), "jitter_size")
})

# ── crossover_rate validation ─────────────────────────────────────────────────

test_that("crossover_rate = 0 accepted", {
expect_no_error(GetAlgoParams(n_params = 2, crossover_rate = 0))
})

test_that("crossover_rate = 1 accepted", {
expect_no_error(GetAlgoParams(n_params = 2, crossover_rate = 1))
})

test_that("crossover_rate > 1 errors", {
expect_error(GetAlgoParams(n_params = 2, crossover_rate = 1.1), "crossover_rate")
})

test_that("crossover_rate < 0 errors", {
expect_error(GetAlgoParams(n_params = 2, crossover_rate = -0.1), "crossover_rate")
})

# ── n_cores_use validation ────────────────────────────────────────────────────

test_that("n_cores_use < 1 errors", {
expect_error(GetAlgoParams(n_params = 2, n_cores_use = 0), "n_cores_use")
})

test_that("n_cores_use coerced to integer", {
expect_identical(GetAlgoParams(n_params = 2, n_cores_use = 2.0)$n_cores_use, 2L)
})

# ── parallel_type validation ──────────────────────────────────────────────────

test_that("invalid parallel_type errors", {
expect_error(GetAlgoParams(n_params = 2, parallel_type = "MPI"), "parallel_type")
})

test_that("valid parallel_type values accepted", {
expect_no_error(GetAlgoParams(n_params = 2, parallel_type = "none"))
expect_no_error(GetAlgoParams(n_params = 2, parallel_type = "PSOCK"))
expect_no_error(GetAlgoParams(n_params = 2, parallel_type = "FORK"))
})

# ── thin validation ───────────────────────────────────────────────────────────

test_that("thin < 1 errors", {
expect_error(GetAlgoParams(n_params = 2, thin = 0), "thin")
})

test_that("thin coerced to integer", {
expect_identical(GetAlgoParams(n_params = 2, thin = 5.0)$thin, 5L)
})

# ── purify validation ─────────────────────────────────────────────────────────

test_that("purify = Inf accepted", {
expect_no_error(GetAlgoParams(n_params = 2, purify = Inf))
})

test_that("purify positive integer accepted", {
expect_no_error(GetAlgoParams(n_params = 2, purify = 10))
})

test_that("purify < 1 errors", {
expect_error(GetAlgoParams(n_params = 2, purify = 0), "purify")
})

# ── give_up_init validation ───────────────────────────────────────────────────

test_that("give_up_init < 1 errors", {
expect_error(GetAlgoParams(n_params = 2, give_up_init = 0), "give_up_init")
})

# ── stop_check validation ─────────────────────────────────────────────────────

test_that("stop_check < 2 errors", {
expect_error(GetAlgoParams(n_params = 2, stop_check = 1), "stop_check")
})

# ── stop_tol validation ───────────────────────────────────────────────────────

test_that("stop_tol < 0 errors", {
expect_error(GetAlgoParams(n_params = 2, stop_tol = -1e-5), "stop_tol")
})

test_that("stop_tol = 0 accepted", {
expect_no_error(GetAlgoParams(n_params = 2, stop_tol = 0))
})

# ── converge_crit validation ──────────────────────────────────────────────────

test_that("invalid converge_crit errors", {
expect_error(GetAlgoParams(n_params = 2, converge_crit = "max"), "converge_crit")
})

test_that("valid converge_crit values accepted", {
expect_no_error(GetAlgoParams(n_params = 2, converge_crit = "stdev"))
expect_no_error(GetAlgoParams(n_params = 2, converge_crit = "percent"))
})
Loading