Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions R/optim_SQGDE.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#' @description Runs Stochastic Quasi-Gradient Differential Evolution (SQG-DE; Sala, Baldanzini, and Pierini, 2018) to minimize an objective function f(x). To maximize a function f(x), simply pass g(x)=-f(x) to ObjFun argument.
#' @param ObjFun A scalar-returning function to minimize whose first argument is a real-valued n_params-dimensional vector.
#' @param control_params control parameters for SQG-DE algo. see \code{\link{GetAlgoParams}} function documentation for more details. The only argument you NEED to pass here is n_params.
#' @param warm_start Optional. Output list from a previous call to \code{optim_SQGDE}. When provided, skips random initialization and seeds the population from the previous run's final particle state. \code{n_params} and \code{n_particles} in \code{control_params} must match the previous run.
#' @param ... additional arguments to pass ObjFun.
#' @return list containing solution and it's corresponding weight (i.e. f(solution)).
#' @export
Expand Down Expand Up @@ -61,7 +62,7 @@
#' par(old_par) # restore user graphic state
#'

optim_SQGDE = function(ObjFun, control_params = GetAlgoParams(), ...){
optim_SQGDE = function(ObjFun, control_params = GetAlgoParams(), warm_start = NULL, ...){

# create memory structures for storing particle trajectories
particles = array(NA,
Expand All @@ -72,33 +73,54 @@ optim_SQGDE = function(ObjFun, control_params = GetAlgoParams(), ...){
nrow = control_params$n_iters_per_particle,
ncol = control_params$n_particles)

# pop initialization
message('initalizing population...')
for(pmem_index in 1:control_params$n_particles){
count = 0 # establish a count variable to avoid infinite run time
while(weights[1,pmem_index]==Inf) {
particles[1, pmem_index, ] = pmax(control_params$lower,
pmin(control_params$upper,
stats::rnorm(control_params$n_params,
control_params$init_center,
control_params$init_sd)))
if (!is.null(warm_start)) {
# warm start: seed population from previous run
if (!all(c('last_particles', 'last_weights') %in% names(warm_start))) {
stop('ERROR: warm_start must be output of optim_SQGDE (missing last_particles or last_weights)')
}
if (!identical(dim(warm_start$last_particles),
c(control_params$n_particles, control_params$n_params))) {
stop(paste0('ERROR: warm_start$last_particles is ',
paste(dim(warm_start$last_particles), collapse = 'x'),
' but control_params expects ',
control_params$n_particles, 'x', control_params$n_params))
}
if (length(warm_start$last_weights) != control_params$n_particles) {
stop(paste0('ERROR: warm_start$last_weights length ', length(warm_start$last_weights),
' does not match control_params$n_particles ', control_params$n_particles))
}
particles[1, , ] = warm_start$last_particles
weights[1, ] = warm_start$last_weights
message('warm start: population seeded from previous run')
} else {
# pop initialization
message('initalizing population...')
for(pmem_index in 1:control_params$n_particles){
count = 0 # establish a count variable to avoid infinite run time
while(weights[1,pmem_index]==Inf) {
particles[1, pmem_index, ] = pmax(control_params$lower,
pmin(control_params$upper,
stats::rnorm(control_params$n_params,
control_params$init_center,
control_params$init_sd)))

weights[1, pmem_index] = ObjFun(particles[1, pmem_index, ], ...)
weights[1, pmem_index] = ObjFun(particles[1, pmem_index, ], ...)

# catcha NA's and Infinity and assign worst possible value
if(!is.finite(weights[1, pmem_index])){
weights[1, pmem_index] = Inf
}
count = count + 1
if(count>control_params$give_up_init){
stop('population initialization failed.
# catcha NA's and Infinity and assign worst possible value
if(!is.finite(weights[1, pmem_index])){
weights[1, pmem_index] = Inf
}
count = count + 1
if(count>control_params$give_up_init){
stop('population initialization failed.
inspect objective function or change init_center/init_sd to sample more
likely parameter values')
}
}
message(paste0(pmem_index, " / ", control_params$n_particles))
}
message(paste0(pmem_index, " / ", control_params$n_particles))
message('population initialization complete :)')
}
message('population initialization complete :)')

# assign adaption scheme
if(control_params$adapt_scheme=='rand'){
Expand Down Expand Up @@ -246,15 +268,24 @@ optim_SQGDE = function(ObjFun, control_params = GetAlgoParams(), ...){
minIdx = which.min(weights[iter_idx, ])
minEst = particles[iter_idx, minIdx, ]

last_particles = matrix(particles[iter_idx, , ],
nrow = control_params$n_particles,
ncol = control_params$n_params)
last_weights = weights[iter_idx, ]

if(control_params$return_trace==TRUE){
return(list('solution' = minEst,
'weight' = weights[iter_idx, minIdx],
'last_particles' = last_particles,
'last_weights' = last_weights,
'particles_trace' = particles,
'weights_trace' = weights,
'converged' = converge_test_passed))
} else {
return(list('solution' = minEst,
'weight' = weights[iter_idx, minIdx],
'last_particles' = last_particles,
'last_weights' = last_weights,
'converged' = converge_test_passed))
}
}
4 changes: 3 additions & 1 deletion man/optim_SQGDE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/test-optim_SQGDE.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ test_that("MLE factorizable MVN recovers sample means", {
)

expect_type(out, "list")
expect_named(out, c("solution", "weight", "converged"))
expect_named(out, c("solution", "weight", "last_particles", "last_weights", "converged"))
expect_length(out$solution, 4)
expect_true(all(is.finite(out$solution)))
expect_lt(max(abs(out$solution - analytic_mu)), 0.1)
Expand Down Expand Up @@ -69,7 +69,7 @@ test_that("univariate objective recovers sample mean", {
)

expect_type(out, "list")
expect_named(out, c("solution", "weight", "converged"))
expect_named(out, c("solution", "weight", "last_particles", "last_weights", "converged"))
expect_length(out$solution, 1)
expect_true(is.finite(out$solution))
expect_lt(abs(out$solution - analytic_mu), 0.1)
Expand Down Expand Up @@ -189,7 +189,7 @@ test_that("PSOCK parallel execution returns valid solution", {
)

expect_type(out, "list")
expect_named(out, c("solution", "weight", "converged"))
expect_named(out, c("solution", "weight", "last_particles", "last_weights", "converged"))
expect_length(out$solution, 4)
expect_true(all(is.finite(out$solution)))
expect_lt(max(abs(out$solution - analytic_mu)), 0.1)
Expand Down Expand Up @@ -235,7 +235,7 @@ test_that("FORK parallel execution returns valid solution", {
)

expect_type(out, "list")
expect_named(out, c("solution", "weight", "converged"))
expect_named(out, c("solution", "weight", "last_particles", "last_weights", "converged"))
expect_length(out$solution, 4)
expect_true(all(is.finite(out$solution)))
expect_lt(max(abs(out$solution - analytic_mu)), 0.1)
Expand Down
160 changes: 160 additions & 0 deletions tests/testthat/test-warm_start.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
library(graDiEnt)

obj_2d <- function(x) sum((x - c(1, 2))^2)

# ── output always contains last_particles and last_weights ────────────────────

test_that("optim_SQGDE returns last_particles and last_weights without return_trace", {
set.seed(1)
out <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, n_diff = 2))
)

expect_true("last_particles" %in% names(out))
expect_true("last_weights" %in% names(out))
expect_equal(dim(out$last_particles), c(8L, 2L))
expect_length(out$last_weights, 8)
expect_true(all(is.finite(out$last_particles)))
expect_true(all(is.finite(out$last_weights)))
})

test_that("optim_SQGDE returns last_particles and last_weights with return_trace", {
set.seed(1)
out <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8,
n_diff = 2, return_trace = TRUE))
)

expect_true("last_particles" %in% names(out))
expect_true("last_weights" %in% names(out))
expect_equal(dim(out$last_particles), c(8L, 2L))
expect_length(out$last_weights, 8)
})

# ── warm start produces valid output ─────────────────────────────────────────

test_that("warm start returns valid solution with correct structure", {
set.seed(1)
out1 <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 100, n_particles = 8,
n_diff = 2, init_sd = 1))
)

set.seed(2)
out2 <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 100, n_particles = 8, n_diff = 2),
warm_start = out1)
)

expect_type(out2, "list")
expect_named(out2, c("solution", "weight", "last_particles", "last_weights", "converged"))
expect_length(out2$solution, 2)
expect_true(all(is.finite(out2$solution)))
expect_true(is.finite(out2$weight))
})

# ── warm start does not degrade solution ─────────────────────────────────────

test_that("warm start weight is no worse than cold start weight", {
set.seed(42)
out1 <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 200, n_particles = 10,
n_diff = 2, init_sd = 1))
)

out2 <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 200, n_particles = 10, n_diff = 2),
warm_start = out1)
)

expect_lte(out2$weight, out1$weight)
})

# ── chained warm starts converge to solution ─────────────────────────────────

test_that("chained warm starts converge closer to true optimum", {
set.seed(7)
cp <- GetAlgoParams(n_params = 2, n_iter = 150, n_particles = 10,
n_diff = 2, init_sd = 2)

out <- suppressMessages(optim_SQGDE(obj_2d, cp))
for (i in 1:3) {
out <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 150, n_particles = 10, n_diff = 2),
warm_start = out)
)
}

expect_lt(max(abs(out$solution - c(1, 2))), 0.1)
})

# ── warm start works with bounds ──────────────────────────────────────────────

test_that("warm start respects bounds from new control_params", {
set.seed(5)
out1 <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 100, n_particles = 8,
n_diff = 2, lower = -5, upper = 5))
)

out2 <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 100, n_particles = 8,
n_diff = 2, lower = -5, upper = 5),
warm_start = out1)
)

expect_true(all(out2$solution >= -5))
expect_true(all(out2$solution <= 5))
expect_lte(out2$weight, out1$weight)
})

# ── validation errors ─────────────────────────────────────────────────────────

test_that("warm_start missing last_particles or last_weights throws error", {
bad <- list(solution = c(0, 0), weight = 1)
expect_error(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, n_diff = 2),
warm_start = bad),
"missing last_particles or last_weights"
)
})

test_that("warm_start n_particles mismatch throws error", {
set.seed(1)
out1 <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, n_diff = 2))
)

expect_error(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 12, n_diff = 2),
warm_start = out1),
"control_params expects"
)
})

test_that("warm_start n_params mismatch throws error", {
set.seed(1)
out1 <- suppressMessages(
optim_SQGDE(obj_2d,
GetAlgoParams(n_params = 2, n_iter = 50, n_particles = 8, n_diff = 2))
)

expect_error(
optim_SQGDE(function(x) sum((x - 1:3)^2),
GetAlgoParams(n_params = 3, n_iter = 50, n_particles = 8, n_diff = 2),
warm_start = out1),
"control_params expects"
)
})
Loading