diff --git a/R/GetAlgoParams.R b/R/GetAlgoParams.R index 9438630..b660f22 100644 --- a/R/GetAlgoParams.R +++ b/R/GetAlgoParams.R @@ -9,7 +9,7 @@ #' @param init_center A scalar or n_params-dimensional numeric vector, determines the mean of the Gaussian initialization distribution. The default value is 0. #' @param n_cores_use An integer specifying the number of cores used when using parallelization. The default value is 1. #' @param step_size A positive scalar, jump size or "F" in the DE crossover step notation. The default value is 2.38/sqrt(2*n_params). -#' @param jitter_size A positive scalar that determines the jitter (noise) size. Noise is added during adaption step from Uniform(-jitter_size,jitter_size) distribution. 1e-6 is the default value. Set to 0 to turn off jitter. +#' @param jitter_size A non-negative scalar that determines the jitter (noise) size. Noise is added during adaption step from Uniform(-jitter_size,jitter_size) distribution. 1e-6 is the default value. Set to 0 to turn off jitter. #' @param parallel_type A string specifying parallelization type. 'none','FORK', or 'PSOCK' are valid values. 'none' is default value. 'FORK' does not work with Windows OS. #' @param recovery_path A character scalar giving a file path where partial results are written via \code{saveRDS} periodically. Allows recovery of progress if the run is interrupted or crashes. Load with \code{readRDS(recovery_path)}. The saved object matches the structure of the return value of \code{optim_SQGDE} but excludes trace arrays. Set to \code{NULL} (default) to disable. #' @param recovery_freq A positive integer controlling how often the recovery file is written. The file is saved every \code{recovery_freq} iterations. Default is 1 (every iteration). Ignored when \code{recovery_path} is \code{NULL}. @@ -54,11 +54,12 @@ GetAlgoParams = function(n_params, trace_print_freq = 100){ # n_params ### catch errors + if(length(n_params) > 1 || !is.finite(n_params)){ + stop('ERROR: n_params must be a positive finite integer scalar') + } n_params = as.integer(n_params) - if(any(!is.finite(n_params))){ - stop('ERROR: n_params is not finite') - } else if( n_params<1 | length(n_params)>1){ - stop('ERROR: n_params must be a postitive integer scalar') + if(n_params < 1){ + stop('ERROR: n_params must be a positive finite integer scalar') } # n_particles @@ -67,11 +68,12 @@ GetAlgoParams = function(n_params, n_particles = max(3*n_params,4) } ### catch errors + if(length(n_particles) > 1 || !is.finite(n_particles)){ + stop('ERROR: n_particles must be a positive finite integer scalar, and at least 4') + } n_particles = as.integer(n_particles) - if(any(!is.finite(n_particles))){ - stop('ERROR: n_particles is not finite') - } else if( n_particles<4 | length(n_particles)>1){ - stop('ERROR: n_particles must be a postitive integer scalar, and atleast 4') + if(n_particles < 4){ + stop('ERROR: n_particles must be a positive finite integer scalar, and at least 4') } # n_iter @@ -80,11 +82,12 @@ GetAlgoParams = function(n_params, n_iter = 1000 } ### catch errors + if(length(n_iter) > 1 || !is.finite(n_iter)){ + stop('ERROR: n_iter must be a positive finite integer scalar, and at least 4') + } n_iter = as.integer(n_iter) - if(any(!is.finite(n_iter))){ - stop('ERROR: n_iter is not finite') - } else if( n_iter<4 | length(n_iter)>1){ - stop('ERROR: n_iter must be a postitive integer scalar, and atleast 4') + if(n_iter < 4){ + stop('ERROR: n_iter must be a positive finite integer scalar, and at least 4') } # init_sd @@ -172,8 +175,8 @@ GetAlgoParams = function(n_params, ### catch any errors if(any(!is.finite(jitter_size))){ stop('ERROR: jitter_size is not finite') - } else if(any(jitter_size<= 0 | is.complex(jitter_size))){ - stop('ERROR: jitter_size must be positive and real-valued') + } else if(any(jitter_size < 0 | is.complex(jitter_size))){ + stop('ERROR: jitter_size must be non-negative and real-valued') } else if(!(length(jitter_size) == 1)){ stop('ERROR: jitter_size vector length must be 1 ') } diff --git a/man/GetAlgoParams.Rd b/man/GetAlgoParams.Rd index b069be5..4ee36a2 100644 --- a/man/GetAlgoParams.Rd +++ b/man/GetAlgoParams.Rd @@ -55,7 +55,7 @@ GetAlgoParams( \item{step_size}{A positive scalar, jump size or "F" in the DE crossover step notation. The default value is 2.38/sqrt(2*n_params).} -\item{jitter_size}{A positive scalar that determines the jitter (noise) size. Noise is added during adaption step from Uniform(-jitter_size,jitter_size) distribution. 1e-6 is the default value. Set to 0 to turn off jitter.} +\item{jitter_size}{A non-negative scalar that determines the jitter (noise) size. Noise is added during adaption step from Uniform(-jitter_size,jitter_size) distribution. 1e-6 is the default value. Set to 0 to turn off jitter.} \item{crossover_rate}{A numeric scalar on the interval [0,1]. Determines the probability a parameter on a chain is updated on a given crossover step, sampled from a Bernoulli distribution. When 0, exactly one randomly chosen parameter is updated per iteration. The default value is 1.} diff --git a/tests/testthat/test-GetAlgoParams.R b/tests/testthat/test-GetAlgoParams.R new file mode 100644 index 0000000..3bb3f38 --- /dev/null +++ b/tests/testthat/test-GetAlgoParams.R @@ -0,0 +1,298 @@ +library(graDiEnt) + +ALL_NAMES <- c("n_params", "n_particles", "n_iter", "init_sd", "init_center", + "lower", "upper", "bounds_type", "n_cores_use", "step_size", + "crossover_rate", "jitter_size", "parallel_type", "recovery_path", + "recovery_freq", "thin", "purify", "n_iters_per_particle", + "return_trace", "n_diff", "adapt_scheme", "give_up_init", + "stop_tol", "stop_check", "converge_crit", "trace_print_freq") + +# ── output structure ────────────────────────────────────────────────────────── + +test_that("GetAlgoParams returns list with all expected names", { + cp <- GetAlgoParams(n_params = 3) + expect_named(cp, ALL_NAMES, ignore.order = TRUE) +}) + +# ── default values ──────────────────────────────────────────────────────────── + +test_that("n_particles default is max(3*n_params, 4)", { + expect_equal(GetAlgoParams(n_params = 5)$n_particles, 15L) + expect_equal(GetAlgoParams(n_params = 1)$n_particles, 4L) # max(3,4)=4 +}) + +test_that("n_iter default is 1000", { + expect_equal(GetAlgoParams(n_params = 2)$n_iter, 1000L) +}) + +test_that("step_size default is 2.38/sqrt(2*n_params)", { + cp <- GetAlgoParams(n_params = 4) + expect_equal(cp$step_size, 2.38 / sqrt(2 * 4)) +}) + +test_that("adapt_scheme default is 'rand'", { + expect_equal(GetAlgoParams(n_params = 2)$adapt_scheme, "rand") +}) + +test_that("bounds_type default is 'reflect'", { + expect_equal(GetAlgoParams(n_params = 2)$bounds_type, "reflect") +}) + +test_that("lower/upper default to vectors of -Inf/Inf length n_params", { + cp <- GetAlgoParams(n_params = 3) + expect_equal(cp$lower, rep(-Inf, 3)) + expect_equal(cp$upper, rep(Inf, 3)) +}) + +test_that("n_iters_per_particle = floor(n_iter / thin)", { + cp <- GetAlgoParams(n_params = 2, n_iter = 100, thin = 3) + expect_equal(cp$n_iters_per_particle, floor(100 / 3)) +}) + +test_that("converge_crit default is 'stdev'", { + expect_equal(GetAlgoParams(n_params = 2)$converge_crit, "stdev") +}) + +test_that("jitter_size default is 1e-6", { + expect_equal(GetAlgoParams(n_params = 2)$jitter_size, 1e-6) +}) + +test_that("crossover_rate default is 1", { + expect_equal(GetAlgoParams(n_params = 2)$crossover_rate, 1) +}) + +test_that("purify default is Inf", { + expect_equal(GetAlgoParams(n_params = 2)$purify, Inf) +}) + +test_that("recovery_path default is NULL", { + expect_null(GetAlgoParams(n_params = 2)$recovery_path) +}) + +test_that("recovery_freq default is 1", { + expect_equal(GetAlgoParams(n_params = 2)$recovery_freq, 1L) +}) + +# ── n_params validation ─────────────────────────────────────────────────────── + +test_that("n_params is coerced to integer", { + cp <- GetAlgoParams(n_params = 3.0) + expect_identical(cp$n_params, 3L) +}) + +test_that("n_params < 1 errors", { + expect_error(GetAlgoParams(n_params = 0), "n_params") +}) + +test_that("n_params non-finite errors", { + expect_error(GetAlgoParams(n_params = Inf), "n_params") +}) + +test_that("n_params length > 1 errors", { + expect_error(GetAlgoParams(n_params = c(2, 3)), "scalar") +}) + +# ── n_particles validation ──────────────────────────────────────────────────── + +test_that("n_particles < 4 errors", { + expect_error(GetAlgoParams(n_params = 2, n_particles = 3), "n_particles") +}) + +test_that("n_particles non-finite errors", { + expect_error(GetAlgoParams(n_params = 2, n_particles = Inf), "n_particles") +}) + +test_that("n_particles coerced to integer", { + cp <- GetAlgoParams(n_params = 2, n_particles = 8.0) + expect_identical(cp$n_particles, 8L) +}) + +# ── n_iter validation ───────────────────────────────────────────────────────── + +test_that("n_iter < 4 errors", { + expect_error(GetAlgoParams(n_params = 2, n_iter = 3), "n_iter") +}) + +test_that("n_iter coerced to integer", { + expect_identical(GetAlgoParams(n_params = 2, n_iter = 100.0)$n_iter, 100L) +}) + +# ── n_diff validation ───────────────────────────────────────────────────────── + +test_that("n_diff > n_particles/2 errors", { + expect_error( + GetAlgoParams(n_params = 2, n_particles = 6, n_diff = 4), + "n_diff" + ) +}) + +test_that("n_diff < 1 errors", { + expect_error(GetAlgoParams(n_params = 2, n_diff = 0), "n_diff") +}) + +test_that("n_diff coerced to integer", { + expect_identical(GetAlgoParams(n_params = 2, n_diff = 2.0)$n_diff, 2L) +}) + +# ── init_sd validation ──────────────────────────────────────────────────────── + +test_that("init_sd <= 0 errors", { + expect_error(GetAlgoParams(n_params = 2, init_sd = 0), "init_sd") + expect_error(GetAlgoParams(n_params = 2, init_sd = -1), "init_sd") +}) + +test_that("init_sd non-finite errors", { + expect_error(GetAlgoParams(n_params = 2, init_sd = Inf), "init_sd") +}) + +test_that("init_sd wrong length errors", { + expect_error(GetAlgoParams(n_params = 2, init_sd = c(0.1, 0.2, 0.3)), "init_sd") +}) + +test_that("init_sd length n_params accepted", { + expect_no_error(GetAlgoParams(n_params = 3, init_sd = c(0.1, 0.2, 0.3))) +}) + +# ── init_center validation ──────────────────────────────────────────────────── + +test_that("init_center non-finite errors", { + expect_error(GetAlgoParams(n_params = 2, init_center = Inf), "init_center") +}) + +test_that("init_center wrong length errors", { + expect_error(GetAlgoParams(n_params = 2, init_center = c(0, 0, 0)), "init_center") +}) + +test_that("init_center length n_params accepted", { + expect_no_error(GetAlgoParams(n_params = 3, init_center = c(1, 2, 3))) +}) + +# ── step_size validation ────────────────────────────────────────────────────── + +test_that("step_size <= 0 errors", { + expect_error(GetAlgoParams(n_params = 2, step_size = 0), "step_size") + expect_error(GetAlgoParams(n_params = 2, step_size = -1), "step_size") +}) + +test_that("step_size non-finite errors", { + expect_error(GetAlgoParams(n_params = 2, step_size = Inf), "step_size") +}) + +test_that("step_size length > 1 errors", { + expect_error(GetAlgoParams(n_params = 2, step_size = c(0.5, 0.5)), "step_size") +}) + +# ── jitter_size validation ──────────────────────────────────────────────────── + +test_that("jitter_size = 0 accepted (turns off jitter)", { + expect_no_error(GetAlgoParams(n_params = 2, jitter_size = 0)) +}) + +test_that("jitter_size < 0 errors", { + expect_error(GetAlgoParams(n_params = 2, jitter_size = -1e-6), "jitter_size") +}) + +test_that("jitter_size non-finite errors", { + expect_error(GetAlgoParams(n_params = 2, jitter_size = Inf), "jitter_size") +}) + +test_that("jitter_size length > 1 errors", { + expect_error(GetAlgoParams(n_params = 2, jitter_size = c(1e-6, 1e-6)), "jitter_size") +}) + +# ── crossover_rate validation ───────────────────────────────────────────────── + +test_that("crossover_rate = 0 accepted", { + expect_no_error(GetAlgoParams(n_params = 2, crossover_rate = 0)) +}) + +test_that("crossover_rate = 1 accepted", { + expect_no_error(GetAlgoParams(n_params = 2, crossover_rate = 1)) +}) + +test_that("crossover_rate > 1 errors", { + expect_error(GetAlgoParams(n_params = 2, crossover_rate = 1.1), "crossover_rate") +}) + +test_that("crossover_rate < 0 errors", { + expect_error(GetAlgoParams(n_params = 2, crossover_rate = -0.1), "crossover_rate") +}) + +# ── n_cores_use validation ──────────────────────────────────────────────────── + +test_that("n_cores_use < 1 errors", { + expect_error(GetAlgoParams(n_params = 2, n_cores_use = 0), "n_cores_use") +}) + +test_that("n_cores_use coerced to integer", { + expect_identical(GetAlgoParams(n_params = 2, n_cores_use = 2.0)$n_cores_use, 2L) +}) + +# ── parallel_type validation ────────────────────────────────────────────────── + +test_that("invalid parallel_type errors", { + expect_error(GetAlgoParams(n_params = 2, parallel_type = "MPI"), "parallel_type") +}) + +test_that("valid parallel_type values accepted", { + expect_no_error(GetAlgoParams(n_params = 2, parallel_type = "none")) + expect_no_error(GetAlgoParams(n_params = 2, parallel_type = "PSOCK")) + expect_no_error(GetAlgoParams(n_params = 2, parallel_type = "FORK")) +}) + +# ── thin validation ─────────────────────────────────────────────────────────── + +test_that("thin < 1 errors", { + expect_error(GetAlgoParams(n_params = 2, thin = 0), "thin") +}) + +test_that("thin coerced to integer", { + expect_identical(GetAlgoParams(n_params = 2, thin = 5.0)$thin, 5L) +}) + +# ── purify validation ───────────────────────────────────────────────────────── + +test_that("purify = Inf accepted", { + expect_no_error(GetAlgoParams(n_params = 2, purify = Inf)) +}) + +test_that("purify positive integer accepted", { + expect_no_error(GetAlgoParams(n_params = 2, purify = 10)) +}) + +test_that("purify < 1 errors", { + expect_error(GetAlgoParams(n_params = 2, purify = 0), "purify") +}) + +# ── give_up_init validation ─────────────────────────────────────────────────── + +test_that("give_up_init < 1 errors", { + expect_error(GetAlgoParams(n_params = 2, give_up_init = 0), "give_up_init") +}) + +# ── stop_check validation ───────────────────────────────────────────────────── + +test_that("stop_check < 2 errors", { + expect_error(GetAlgoParams(n_params = 2, stop_check = 1), "stop_check") +}) + +# ── stop_tol validation ─────────────────────────────────────────────────────── + +test_that("stop_tol < 0 errors", { + expect_error(GetAlgoParams(n_params = 2, stop_tol = -1e-5), "stop_tol") +}) + +test_that("stop_tol = 0 accepted", { + expect_no_error(GetAlgoParams(n_params = 2, stop_tol = 0)) +}) + +# ── converge_crit validation ────────────────────────────────────────────────── + +test_that("invalid converge_crit errors", { + expect_error(GetAlgoParams(n_params = 2, converge_crit = "max"), "converge_crit") +}) + +test_that("valid converge_crit values accepted", { + expect_no_error(GetAlgoParams(n_params = 2, converge_crit = "stdev")) + expect_no_error(GetAlgoParams(n_params = 2, converge_crit = "percent")) +})