From 2b96230ec2efaa240f36a1ae0a1b8275d31cde69 Mon Sep 17 00:00:00 2001 From: Brendan Matthew Galdo Date: Sun, 10 May 2026 14:45:22 -0400 Subject: [PATCH 1/3] warm start feature --- R/optim_SQGDE.R | 73 ++++++++++---- tests/testthat/test-warm_start.R | 160 +++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+), 21 deletions(-) create mode 100644 tests/testthat/test-warm_start.R diff --git a/R/optim_SQGDE.R b/R/optim_SQGDE.R index 2ff385a..7fd8243 100644 --- a/R/optim_SQGDE.R +++ b/R/optim_SQGDE.R @@ -3,6 +3,7 @@ #' @description Runs Stochastic Quasi-Gradient Differential Evolution (SQG-DE; Sala, Baldanzini, and Pierini, 2018) to minimize an objective function f(x). To maximize a function f(x), simply pass g(x)=-f(x) to ObjFun argument. #' @param ObjFun A scalar-returning function to minimize whose first argument is a real-valued n_params-dimensional vector. #' @param control_params control parameters for SQG-DE algo. see \code{\link{GetAlgoParams}} function documentation for more details. The only argument you NEED to pass here is n_params. +#' @param warm_start Optional. Output list from a previous call to \code{optim_SQGDE}. When provided, skips random initialization and seeds the population from the previous run's final particle state. \code{n_params} and \code{n_particles} in \code{control_params} must match the previous run. #' @param ... additional arguments to pass ObjFun. #' @return list containing solution and it's corresponding weight (i.e. f(solution)). #' @export @@ -61,7 +62,7 @@ #' par(old_par) # restore user graphic state #' -optim_SQGDE = function(ObjFun, control_params = GetAlgoParams(), ...){ +optim_SQGDE = function(ObjFun, control_params = GetAlgoParams(), warm_start = NULL, ...){ # create memory structures for storing particle trajectories particles = array(NA, @@ -72,33 +73,54 @@ optim_SQGDE = function(ObjFun, control_params = GetAlgoParams(), ...){ nrow = control_params$n_iters_per_particle, ncol = control_params$n_particles) - # pop initialization - message('initalizing population...') - for(pmem_index in 1:control_params$n_particles){ - count = 0 # establish a count variable to avoid infinite run time - while(weights[1,pmem_index]==Inf) { - particles[1, pmem_index, ] = pmax(control_params$lower, - pmin(control_params$upper, - stats::rnorm(control_params$n_params, - control_params$init_center, - control_params$init_sd))) + if (!is.null(warm_start)) { + # warm start: seed population from previous run + if (!all(c('last_particles', 'last_weights') %in% names(warm_start))) { + stop('ERROR: warm_start must be output of optim_SQGDE (missing last_particles or last_weights)') + } + if (!identical(dim(warm_start$last_particles), + c(control_params$n_particles, control_params$n_params))) { + stop(paste0('ERROR: warm_start$last_particles is ', + paste(dim(warm_start$last_particles), collapse = 'x'), + ' but control_params expects ', + control_params$n_particles, 'x', control_params$n_params)) + } + if (length(warm_start$last_weights) != control_params$n_particles) { + stop(paste0('ERROR: warm_start$last_weights length ', length(warm_start$last_weights), + ' does not match control_params$n_particles ', control_params$n_particles)) + } + particles[1, , ] = warm_start$last_particles + weights[1, ] = warm_start$last_weights + message('warm start: population seeded from previous run') + } else { + # pop initialization + message('initalizing population...') + for(pmem_index in 1:control_params$n_particles){ + count = 0 # establish a count variable to avoid infinite run time + while(weights[1,pmem_index]==Inf) { + particles[1, pmem_index, ] = pmax(control_params$lower, + pmin(control_params$upper, + stats::rnorm(control_params$n_params, + control_params$init_center, + control_params$init_sd))) - weights[1, pmem_index] = ObjFun(particles[1, pmem_index, ], ...) + weights[1, pmem_index] = ObjFun(particles[1, pmem_index, ], ...) - # catcha NA's and Infinity and assign worst possible value - if(!is.finite(weights[1, pmem_index])){ - weights[1, pmem_index] = Inf - } - count = count + 1 - if(count>control_params$give_up_init){ - stop('population initialization failed. + # catcha NA's and Infinity and assign worst possible value + if(!is.finite(weights[1, pmem_index])){ + weights[1, pmem_index] = Inf + } + count = count + 1 + if(count>control_params$give_up_init){ + stop('population initialization failed. inspect objective function or change init_center/init_sd to sample more likely parameter values') + } } + message(paste0(pmem_index, " / ", control_params$n_particles)) } - message(paste0(pmem_index, " / ", control_params$n_particles)) + message('population initialization complete :)') } - message('population initialization complete :)') # assign adaption scheme if(control_params$adapt_scheme=='rand'){ @@ -246,15 +268,24 @@ optim_SQGDE = function(ObjFun, control_params = GetAlgoParams(), ...){ minIdx = which.min(weights[iter_idx, ]) minEst = particles[iter_idx, minIdx, ] + last_particles = matrix(particles[iter_idx, , ], + nrow = control_params$n_particles, + ncol = control_params$n_params) + last_weights = weights[iter_idx, ] + if(control_params$return_trace==TRUE){ return(list('solution' = minEst, 'weight' = weights[iter_idx, minIdx], + 'last_particles' = last_particles, + 'last_weights' = last_weights, 'particles_trace' = particles, 'weights_trace' = weights, 'converged' = converge_test_passed)) } else { return(list('solution' = minEst, 'weight' = weights[iter_idx, minIdx], + 'last_particles' = last_particles, + 'last_weights' = last_weights, 'converged' = converge_test_passed)) } } diff --git a/tests/testthat/test-warm_start.R b/tests/testthat/test-warm_start.R new file mode 100644 index 0000000..7f50db5 --- /dev/null +++ b/tests/testthat/test-warm_start.R @@ -0,0 +1,160 @@ +library(graDiEnt) + +obj_2d <- function(x) sum((x - c(1, 2))^2) + +# ── output always contains last_particles and last_weights ──────────────────── + +test_that("optim_SQGDE returns last_particles and last_weights without return_trace", { + set.seed(1) + out <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, n_diff = 2)) + ) + + expect_true("last_particles" %in% names(out)) + expect_true("last_weights" %in% names(out)) + expect_equal(dim(out$last_particles), c(8L, 2L)) + expect_length(out$last_weights, 8) + expect_true(all(is.finite(out$last_particles))) + expect_true(all(is.finite(out$last_weights))) +}) + +test_that("optim_SQGDE returns last_particles and last_weights with return_trace", { + set.seed(1) + out <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, + n_diff = 2, return_trace = TRUE)) + ) + + expect_true("last_particles" %in% names(out)) + expect_true("last_weights" %in% names(out)) + expect_equal(dim(out$last_particles), c(8L, 2L)) + expect_length(out$last_weights, 8) +}) + +# ── warm start produces valid output ───────────────────────────────────────── + +test_that("warm start returns valid solution with correct structure", { + set.seed(1) + out1 <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 100, n_particles = 8, + n_diff = 2, init_sd = 1)) + ) + + set.seed(2) + out2 <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 100, n_particles = 8, n_diff = 2), + warm_start = out1) + ) + + expect_type(out2, "list") + expect_named(out2, c("solution", "weight", "last_particles", "last_weights", "converged")) + expect_length(out2$solution, 2) + expect_true(all(is.finite(out2$solution))) + expect_true(is.finite(out2$weight)) +}) + +# ── warm start does not degrade solution ───────────────────────────────────── + +test_that("warm start weight is no worse than cold start weight", { + set.seed(42) + out1 <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 200, n_particles = 10, + n_diff = 2, init_sd = 1)) + ) + + out2 <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 200, n_particles = 10, n_diff = 2), + warm_start = out1) + ) + + expect_lte(out2$weight, out1$weight) +}) + +# ── chained warm starts converge to solution ───────────────────────────────── + +test_that("chained warm starts converge closer to true optimum", { + set.seed(7) + cp <- GetAlgoParams(n_params = 2, n_iter = 150, n_particles = 10, + n_diff = 2, init_sd = 2) + + out <- suppressMessages(optim_SQGDE(obj_2d, cp)) + for (i in 1:3) { + out <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 150, n_particles = 10, n_diff = 2), + warm_start = out) + ) + } + + expect_lt(max(abs(out$solution - c(1, 2))), 0.1) +}) + +# ── warm start works with bounds ────────────────────────────────────────────── + +test_that("warm start respects bounds from new control_params", { + set.seed(5) + out1 <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 100, n_particles = 8, + n_diff = 2, lower = -5, upper = 5)) + ) + + out2 <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 100, n_particles = 8, + n_diff = 2, lower = -5, upper = 5), + warm_start = out1) + ) + + expect_true(all(out2$solution >= -5)) + expect_true(all(out2$solution <= 5)) + expect_lte(out2$weight, out1$weight) +}) + +# ── validation errors ───────────────────────────────────────────────────────── + +test_that("warm_start missing last_particles or last_weights throws error", { + bad <- list(solution = c(0, 0), weight = 1) + expect_error( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, n_diff = 2), + warm_start = bad), + "missing last_particles or last_weights" + ) +}) + +test_that("warm_start n_particles mismatch throws error", { + set.seed(1) + out1 <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, n_diff = 2)) + ) + + expect_error( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 12, n_diff = 2), + warm_start = out1), + "control_params expects" + ) +}) + +test_that("warm_start n_params mismatch throws error", { + set.seed(1) + out1 <- suppressMessages( + optim_SQGDE(obj_2d, + GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, n_diff = 2)) + ) + + expect_error( + optim_SQGDE(function(x) sum((x - 1:3)^2), + GetAlgoParams(n_params = 3, n_iter = 50, n_particles = 8, n_diff = 2), + warm_start = out1), + "control_params expects" + ) +}) From 1f25bef33c9cda72bc72e2139104d6c142102bd5 Mon Sep 17 00:00:00 2001 From: Brendan Matthew Galdo Date: Sun, 10 May 2026 14:46:14 -0400 Subject: [PATCH 2/3] update documentation --- man/optim_SQGDE.Rd | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/man/optim_SQGDE.Rd b/man/optim_SQGDE.Rd index 89a66a4..31e395d 100644 --- a/man/optim_SQGDE.Rd +++ b/man/optim_SQGDE.Rd @@ -4,13 +4,15 @@ \alias{optim_SQGDE} \title{optim_SQGDE} \usage{ -optim_SQGDE(ObjFun, control_params = GetAlgoParams(), ...) +optim_SQGDE(ObjFun, control_params = GetAlgoParams(), warm_start = NULL, ...) } \arguments{ \item{ObjFun}{A scalar-returning function to minimize whose first argument is a real-valued n_params-dimensional vector.} \item{control_params}{control parameters for SQG-DE algo. see \code{\link{GetAlgoParams}} function documentation for more details. The only argument you NEED to pass here is n_params.} +\item{warm_start}{Optional. Output list from a previous call to \code{optim_SQGDE}. When provided, skips random initialization and seeds the population from the previous run's final particle state. \code{n_params} and \code{n_particles} in \code{control_params} must match the previous run.} + \item{...}{additional arguments to pass ObjFun.} } \value{ From 76a4d12c7a7611ceeec6da631d568198e59f3bb3 Mon Sep 17 00:00:00 2001 From: Brendan Matthew Galdo Date: Sun, 10 May 2026 14:47:39 -0400 Subject: [PATCH 3/3] update failing test --- tests/testthat/test-optim_SQGDE.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-optim_SQGDE.R b/tests/testthat/test-optim_SQGDE.R index a7228b8..59331cc 100644 --- a/tests/testthat/test-optim_SQGDE.R +++ b/tests/testthat/test-optim_SQGDE.R @@ -35,7 +35,7 @@ test_that("MLE factorizable MVN recovers sample means", { ) expect_type(out, "list") - expect_named(out, c("solution", "weight", "converged")) + expect_named(out, c("solution", "weight", "last_particles", "last_weights", "converged")) expect_length(out$solution, 4) expect_true(all(is.finite(out$solution))) expect_lt(max(abs(out$solution - analytic_mu)), 0.1) @@ -69,7 +69,7 @@ test_that("univariate objective recovers sample mean", { ) expect_type(out, "list") - expect_named(out, c("solution", "weight", "converged")) + expect_named(out, c("solution", "weight", "last_particles", "last_weights", "converged")) expect_length(out$solution, 1) expect_true(is.finite(out$solution)) expect_lt(abs(out$solution - analytic_mu), 0.1) @@ -189,7 +189,7 @@ test_that("PSOCK parallel execution returns valid solution", { ) expect_type(out, "list") - expect_named(out, c("solution", "weight", "converged")) + expect_named(out, c("solution", "weight", "last_particles", "last_weights", "converged")) expect_length(out$solution, 4) expect_true(all(is.finite(out$solution))) expect_lt(max(abs(out$solution - analytic_mu)), 0.1) @@ -235,7 +235,7 @@ test_that("FORK parallel execution returns valid solution", { ) expect_type(out, "list") - expect_named(out, c("solution", "weight", "converged")) + expect_named(out, c("solution", "weight", "last_particles", "last_weights", "converged")) expect_length(out$solution, 4) expect_true(all(is.finite(out$solution))) expect_lt(max(abs(out$solution - analytic_mu)), 0.1)