Skip to content

Commit f780e5f

Browse files
committed
trans_models_t: smaller model objs, tables
1 parent 91925ed commit f780e5f

7 files changed

Lines changed: 41 additions & 65 deletions

File tree

R/evoland_db_tables.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ evoland_db$set("active", "trans_models_t", function(x) {
182182
self,
183183
"trans_models_t",
184184
as_trans_models_t,
185-
key_cols = c("id_trans"),
185+
key_cols = c("id_trans", "fit_call"),
186186
map_cols = c("model_params", "goodness_of_fit")
187187
)(x)
188188
})

R/trans_models_glm.R

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,21 @@ fit_glm <- function(data, result_col = "result", ...) {
3030
formula_str <- paste(result_col, "~", paste(pred_cols, collapse = " + "))
3131
formula <- as.formula(formula_str)
3232

33-
# Fit GLM with quasibinomial family (handles overdispersion better)
3433
model <- glm(formula, data = data, family = quasibinomial())
3534

36-
# Butcher the model if package is available (reduces memory footprint)
35+
# clean up the object
36+
model[["model"]] <- NULL
37+
model[["residuals"]] <- NULL
38+
model[["fitted.values"]] <- NULL
39+
model[["effects"]] <- NULL
40+
model[["qr"]][["qr"]] <- NULL
41+
model[["linear.predictors"]] <- NULL
42+
model[["weights"]] <- NULL
43+
model[["prior.weights"]] <- NULL
44+
model[["y"]] <- NULL
45+
attr(model[["formula"]], ".Environment") <- NULL
46+
47+
# just to be sure
3748
if (requireNamespace("butcher", quietly = TRUE)) {
3849
model <- butcher::butcher(model)
3950
}

R/trans_models_rf.R

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
#' - min.node.size = 1
2828
#'
2929
#' @export
30-
fit_ranger <- function(data, result_col = "result", ...) {
30+
fit_ranger <- function(data, result_col = "result", num.trees = 500, ...) {
3131
if (!requireNamespace("ranger", quietly = TRUE)) {
3232
stop(
3333
"Package 'ranger' is required but is not installed.\n",
@@ -58,10 +58,7 @@ fit_ranger <- function(data, result_col = "result", ...) {
5858
x = x,
5959
y = y,
6060
num.trees = num.trees,
61-
min.node.size = min.min.node.size,
6261
case.weights = weights,
63-
# Stratified sampling by class
64-
sample.fraction = c(nmin / class_counts[1], nmin / class_counts[2]),
6562
probability = TRUE, # For probability predictions
6663
importance = "impurity",
6764
...

R/trans_models_t.R

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#' - `model_family`: Model family (e.g., "rf", "glm", "bayesian")
1414
#' - `model_params`: Map of model (hyper) parameters
1515
#' - `goodness_of_fit`: Map of various measures of fit (e.g., ROC AUC, RMSE)
16-
#' - `sampled_coords`: data.table with id_coord/id_period pairs used for training
1716
#' - `fit_call`: Character string of the original fit function call for reproducibility
1817
#' - `model_obj_part`: BLOB of serialized model object for validation
1918
#' - `model_obj_full`: BLOB of serialized model object for extrapolation
@@ -25,7 +24,6 @@ as_trans_models_t <- function(x) {
2524
model_family = character(0),
2625
model_params = list(),
2726
goodness_of_fit = list(),
28-
sampled_coords = list(),
2927
fit_call = character(0),
3028
model_obj_part = list(),
3129
model_obj_full = list()
@@ -144,9 +142,6 @@ evoland_db$set(
144142
train_data <- trans_pred_data_full[train_idx, ]
145143
test_data <- trans_pred_data_full[test_idx, ]
146144

147-
# Record sampled coordinates for reproducibility
148-
sampled_coords <- train_data[, .(id_coord, id_period)]
149-
150145
message(glue::glue(
151146
" Training on {nrow(train_data)} observations ",
152147
"({n_train_true} TRUE, {n_train_false} FALSE)"
@@ -184,8 +179,7 @@ evoland_db$set(
184179
goodness_of_fit <- gof_fun(
185180
model = model,
186181
test_data = test_data,
187-
result_col = "result",
188-
...
182+
result_col = "result"
189183
)
190184

191185
# Extract model family
@@ -201,7 +195,8 @@ evoland_db$set(
201195
model_params <- list(
202196
n_predictors = length(pred_cols),
203197
n_train = nrow(train_data),
204-
sample_pct = sample_pct
198+
sample_pct = sample_pct,
199+
...
205200
)
206201

207202
# Serialize partial model
@@ -218,7 +213,6 @@ evoland_db$set(
218213
model_family = model_family,
219214
model_params = list(model_params),
220215
goodness_of_fit = list(goodness_of_fit),
221-
sampled_coords = list(sampled_coords),
222216
fit_call = fit_call,
223217
model_obj_part = model_obj_part,
224218
model_obj_full = list(NULL)
@@ -356,7 +350,6 @@ evoland_db$set(
356350
model_family = best_models$model_family[i],
357351
model_params = list(best_models$model_params[[i]]),
358352
goodness_of_fit = list(best_models$goodness_of_fit[[i]]),
359-
sampled_coords = list(best_models$sampled_coords[[i]]),
360353
fit_call = best_models$fit_call[i],
361354
model_obj_part = list(best_models$model_obj_part[[i]]),
362355
model_obj_full = model_obj_full
@@ -393,7 +386,6 @@ validate.trans_models_t <- function(x, ...) {
393386
"model_family",
394387
"model_params",
395388
"goodness_of_fit",
396-
"sampled_coords",
397389
"fit_call",
398390
"model_obj_part",
399391
"model_obj_full"
@@ -410,7 +402,6 @@ validate.trans_models_t <- function(x, ...) {
410402
is.character(x[["model_family"]]),
411403
is.list(x[["model_params"]]),
412404
is.list(x[["goodness_of_fit"]]),
413-
is.list(x[["sampled_coords"]]),
414405
is.character(x[["fit_call"]]),
415406
is.list(x[["model_obj_part"]]),
416407
is.list(x[["model_obj_full"]]),

inst/tinytest/test_trans_models_t.R

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ expect_stdout(print(trans_models_empty), "Transition Models Table")
66
expect_equal(nrow(trans_models_empty), 0L)
77
expect_true(inherits(trans_models_empty, "trans_models_t"))
88

9-
# Test creation with data (updated structure with sampled_coords and fit_call)
9+
# Test creation with data
1010
trans_models_t <- as_trans_models_t(data.table::data.table(
1111
id_trans = 1L,
1212
model_family = "rf",
@@ -16,9 +16,6 @@ trans_models_t <- as_trans_models_t(data.table::data.table(
1616
goodness_of_fit = list(
1717
list(auc = 0.8, rmse = 0.15)
1818
),
19-
sampled_coords = list(
20-
data.table::data.table(id_coord = 1:10, id_period = rep(1L, 10))
21-
),
2219
fit_call = "fit_fun(data = data, result_col = \"result\")",
2320
model_obj_part = list(
2421
charToRaw("partial model data")
@@ -195,52 +192,35 @@ expect_message(
195192
gof_fun = gof_mock,
196193
sample_pct = 70,
197194
seed = 123,
198-
na_value = 0
195+
na_value = 0,
196+
other_param = "nonce"
199197
),
200198
"Fitting partial model"
201199
)
202-
203-
# test DB round trip
204-
expect_silent(db_tm$trans_models_t <- partial_models)
205-
expect_equivalent(db_tm$trans_models_t, partial_models)
206-
207-
expect_true(inherits(partial_models, "trans_models_t"))
208-
expect_true(nrow(partial_models) > 0L)
209-
expect_true(all(partial_models$id_trans > 0))
210-
211-
# Check that partial models are present
212-
expect_true(all(!vapply(partial_models$model_obj_part, is.null, logical(1))))
213-
214-
# Check that full models are NULL
215-
expect_true(all(vapply(partial_models$model_obj_full, is.null, logical(1))))
216-
217-
# Check that sampled_coords is present
218-
expect_true(all(!vapply(partial_models$sampled_coords, is.null, logical(1))))
219-
first_sampled <- partial_models$sampled_coords[[1]]
220-
expect_true(inherits(first_sampled, "data.table"))
221-
expect_true("id_coord" %in% names(first_sampled))
222-
expect_true("id_period" %in% names(first_sampled))
223-
224-
# Check that fit_call is present and is character string
225-
expect_true(all(nchar(partial_models$fit_call) > 0))
226-
expect_true(is.character(partial_models$fit_call))
227-
228-
# Check that fit_call contains the function name
229-
first_call <- partial_models$fit_call[1]
230-
expect_true(grepl("fit_mock_glm", first_call))
231-
expect_true(grepl("data.*result_col", first_call))
232-
233-
# Check that goodness_of_fit is populated
234-
first_gof <- partial_models$goodness_of_fit[[1]]
235-
expect_true(length(first_gof) > 0)
236-
expect_true("cor" %in% names(first_gof) || "mse" %in% names(first_gof))
200+
expect_equal(
201+
partial_models$fit_call[1],
202+
r"{fit_mock_glm(data = data, result_col = "result", other_param = "nonce")}"
203+
)
204+
expect_equal(
205+
partial_models$model_params[[1]],
206+
list(n_predictors = 3, n_train = 17, sample_pct = 70, other_param = "nonce")
207+
)
208+
expect_true(all(
209+
!vapply(partial_models$model_obj_part, is.null, logical(1))
210+
))
211+
expect_equal(
212+
partial_models$goodness_of_fit[[1]],
213+
list(cor = 0.6917245, mse = 0.1610296, n_test = 5),
214+
tolerance = 1e-06
215+
)
237216

238217
# Test that model deserialization works
239218
first_model_part <- qs2::qs_deserialize(partial_models$model_obj_part[[1]])
240219
expect_true(inherits(first_model_part, "glm"))
241220

242-
# Note: Reproducibility with seed can be affected by RNG state from previous operations
243-
# Skipping reproducibility test for now
221+
# test DB round trip
222+
expect_silent(db_tm$trans_models_t <- partial_models)
223+
expect_equivalent(db_tm$trans_models_t, partial_models)
244224

245225
# Test fit_full_models
246226
expect_message(
@@ -257,7 +237,6 @@ expect_message(
257237
expect_silent(db_tm$trans_models_t <- full_models)
258238
expect_identical(db_tm$trans_models_t, full_models)
259239

260-
261240
expect_true(inherits(full_models, "trans_models_t"))
262241
expect_true(nrow(full_models) > 0L)
263242

@@ -430,7 +409,6 @@ expect_equal(
430409
"model_family",
431410
"model_params",
432411
"goodness_of_fit",
433-
"sampled_coords",
434412
"fit_call",
435413
"model_obj_part",
436414
"model_obj_full"

man/fit_ranger.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/trans_models_t.Rd

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)