Skip to content

Commit 8f70efd

Browse files
author
ercbk
committed
performance experiment initial commit
1 parent 73d1786 commit 8f70efd

File tree

8 files changed

+520
-0
lines changed

8 files changed

+520
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Create Hyperparameter grid list
2+
3+
4+
# input:
5+
# 1. size = number of rows in each grid
6+
# 2. algorithms = list of algorithm abbreviations
7+
# "rf" = Ranger Random Forest
8+
# "glmnet" = Elastic Net regression
9+
# "svm" = Support Vector Machines
10+
11+
pacman::p_load(dplyr)
12+
13+
# Build a named list of hyperparameter grids, one per requested algorithm.
#
# Args:
#   algorithms: list (or character vector) of algorithm abbreviations;
#               supported values are "rf", "glmnet" and "svm".
#   size:       number of rows requested from dials::grid_latin_hypercube()
#               for each grid.
#
# Returns:
#   A list of grids in the same order as `algorithms`, named by abbreviation.
#
# Raises:
#   An error "<alg> grid not available." for any unsupported abbreviation.
create_grids <- function(algorithms, size = 100) {

  supported <- c("rf", "glmnet", "svm")

  # Validate before building anything, using base stop() instead of an
  # unexported helper from another package (`:::` internals can change or
  # disappear without notice).
  for (alg in algorithms) {
    if (!(is.character(alg) && length(alg) == 1L && alg %in% supported)) {
      stop(alg, " grid not available.", call. = FALSE)
    }
  }

  # All three grids are built, in this fixed order, so that the random
  # numbers drawn do not depend on which algorithms were requested.

  # Elastic Net Regression
  glm_params <- dials::grid_latin_hypercube(
    dials::mixture(),
    dials::penalty(),
    size = size
  )

  # Random Forest
  rf_params <- dials::grid_latin_hypercube(
    dials::mtry(range = c(3, 4)),
    dials::trees(range = c(200, 300)),
    size = size
  )

  # Support Vector Machines
  svm_params <- dials::grid_latin_hypercube(
    dials::cost(),
    dials::margin(),
    size = size
  )

  available_grids <- list(
    rf = rf_params,
    glmnet = glm_params,
    svm = svm_params
  )

  grid_list <- purrr::map(algorithms, function(alg) available_grids[[alg]])
  purrr::set_names(grid_list, algorithms)
}
50+
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Creates list of model functions
2+
3+
# input: list of algorithm abbreviations
4+
# "rf" = Ranger Random Forest
5+
# "glmnet" = Elastic Net regression
6+
# "svm" = Support Vector Machines
7+
8+
# output: list of model functions
9+
10+
pacman::p_load(dplyr)
11+
12+
# Build a named list of model-fitting functions, one per requested algorithm.
#
# Args:
#   algorithms: list (or character vector) of algorithm abbreviations;
#               supported values are "rf", "glmnet" and "svm".
#
# Returns:
#   A list of functions in the same order as `algorithms`, named by
#   abbreviation. Each function takes (params, analysis_set), where `params`
#   holds one hyperparameter combination and `analysis_set` is a data frame
#   with a response column named `y`, and returns the fitted model.
#
# Raises:
#   An error "<alg> model function not available." for any unsupported
#   abbreviation.
create_models <- function(algorithms) {

  # Random Forest
  ranger_FUN <- function(params, analysis_set) {
    mtry <- params$mtry[[1]]
    trees <- params$trees[[1]]
    ranger::ranger(y ~ ., data = analysis_set, mtry = mtry, num.trees = trees)
  }

  # Elastic Net Regression
  glm_FUN <- function(params, analysis_set) {
    alpha <- params$mixture[[1]]
    lambda <- params$penalty[[1]]
    parsnip::linear_reg(mixture = alpha, penalty = lambda) %>%
      parsnip::set_engine("glmnet") %>%
      generics::fit(y ~ ., data = analysis_set)
  }

  # Support Vector Machines
  # NOTE(review): the svm grid built by create_grids() also varies
  # dials::margin(), but only `cost` is passed to ksvm() here - confirm
  # whether margin was meant to be used as well.
  svm_FUN <- function(params, analysis_set) {
    cost <- params$cost[[1]]
    kernlab::ksvm(y ~ ., data = analysis_set, C = cost)
  }

  available_models <- list(
    rf = ranger_FUN,
    glmnet = glm_FUN,
    svm = svm_FUN
  )

  # Validate with base stop() rather than an unexported helper from another
  # package; checking before the map also keeps the message from being
  # wrapped by the iteration machinery.
  for (alg in algorithms) {
    is_known <- is.character(alg) && length(alg) == 1L &&
      alg %in% names(available_models)
    if (!is_known) {
      stop(alg, " model function not available.", call. = FALSE)
    }
  }

  mod_FUN_list <- purrr::map(algorithms, function(alg) available_models[[alg]])
  purrr::set_names(mod_FUN_list, algorithms)
}
53+
54+
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# nested-cv data function
2+
3+
4+
5+
6+
# Build nested cross-validation resamples for every (data set, repeats) pair.
#
# Args:
#   dat:     list of data frames to resample.
#   repeats: numeric vector; each value is used as the number of repeats of
#            the outer v-fold cross-validation.
#   method:  "kj"      = 10-fold outer CV with 25 bootstraps inside, or
#            "raschka" = 5-fold outer CV with 2-fold CV inside.
#
# Returns:
#   A list of rsample::nested_cv() objects, one per row of
#   tidyr::crossing(dat, repeats).
#
# Raises:
#   An error if `repeats` is not numeric, `method` is not character, or
#   `method` is not one of the two supported values.
create_ncv <- function(dat, repeats, method) {

  if (!is.numeric(repeats)) {
    stop("repeats needs to be a numeric class", call. = FALSE)
  }
  if (!is.character(method)) {
    stop("method needs to be a character class", call. = FALSE)
  }
  # Reject unknown (or non-scalar) methods before doing any work, so the
  # `if` below always sees a single string.
  if (length(method) != 1L || !method %in% c("kj", "raschka")) {
    stop("Need to specify method as kj or raschka", call. = FALSE)
  }

  grid <- tidyr::crossing(dat, repeats)

  # nested_cv() captures `outside`/`inside` as unevaluated calls, which is
  # why `reps` is looked up with dynGet() instead of being passed directly.
  # The resampling functions are namespace-qualified so this works without
  # rsample being attached.
  if (method == "kj") {
    build_ncv <- function(dat, reps) {
      rsample::nested_cv(dat,
        outside = rsample::vfold_cv(v = 10, repeats = dynGet("reps")),
        inside = rsample::bootstraps(times = 25)
      )
    }
  } else {
    build_ncv <- function(dat, reps) {
      rsample::nested_cv(dat,
        outside = rsample::vfold_cv(v = 5, repeats = dynGet("reps")),
        inside = rsample::vfold_cv(v = 2)
      )
    }
  }

  purrr::map2(grid$dat, grid$repeats, build_ncv)
}
31+
32+
33+
34+
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# inner loop tuning function
2+
3+
4+
5+
pacman::p_load(dplyr, furrr, data.table, dtplyr)
6+
7+
# Inner-loop hyperparameter tuning for nested cross-validation.
#
# Args:
#   ncv_dat:      nested-cv object; its `inner_resamples` element holds one
#                 set of inner resamples (with a `splits` element) per outer
#                 fold.
#   mod_FUN_list: named list of model functions taking (params, analysis_set).
#   params_list:  list of hyperparameter grids, parallel to `mod_FUN_list`.
#   error_FUN:    function(y_obs, y_hat) returning a single numeric error.
#
# Returns:
#   A list parallel to `mod_FUN_list`; each element is a data frame with one
#   row of chosen hyperparameter values per outer fold.
inner_tune <- function(ncv_dat, mod_FUN_list, params_list, error_FUN) {

  # Fits the model for one hyperparameter combination on a resample's
  # analysis set and returns its error on the assessment set.
  mod_error <- function(params, mod_FUN, dat) {
    assess_set <- rsample::assessment(dat)
    # The response is taken from the last column of the resampled data.
    y_col <- ncol(dat$data)
    y_obs <- assess_set[y_col]
    mod <- mod_FUN(params, rsample::analysis(dat))
    pred <- predict(mod, assess_set)
    # Predictions that are not a data frame are expected to carry them in a
    # `predictions` element.
    if (!is.data.frame(pred)) {
      pred <- pred$predictions
    }
    error_FUN(y_obs, pred)
  }

  # Evaluates every row of the grid on one resample and returns the grid
  # with an added `error` column.
  tune_over_params <- function(dat, mod_FUN, params) {
    # seq_len() keeps a zero-row grid from being iterated as c(1, 0).
    params$error <- purrr::map_dbl(seq_len(nrow(params)), function(row) {
      mod_error(params[row, ], mod_FUN, dat)
    })
    params
  }

  # Sends one fold's resamples to the tuning function, then averages the
  # error of each hyperparameter combination across those resamples.
  summarize_tune_results <- function(dat, mod_FUN, params) {
    param_names <- names(params)
    # Row-bound results: one block of grid rows per inner resample.
    furrr::future_map_dfr(dat$splits, tune_over_params, mod_FUN, params, .progress = TRUE) %>%
      lazy_dt(., key_by = param_names) %>%
      # For each value of the tuning parameters, compute the average error,
      # which is the inner resampling estimate.
      group_by_at(vars(all_of(param_names))) %>%
      summarize(mean_error = mean(error, na.rm = TRUE),
                sd_error = sd(error, na.rm = TRUE),
                n = length(error)) %>%
      as_tibble()
  }

  # For each algorithm, pick the best hyperparameter combination per outer
  # fold (e.g. 5 repeats of 10 folds = 50 chosen combinations). Lowest mean
  # error wins; ties are broken by the smallest standard deviation.
  purrr::map2(mod_FUN_list, params_list, function(mod_FUN, params) {
    tuning_results <- purrr::map(ncv_dat$inner_resamples, summarize_tune_results, mod_FUN, params)

    tuning_results %>%
      purrr::map_df(function(fold_results) {
        fold_results %>%
          filter(mean_error == min(mean_error)) %>%
          arrange(sd_error) %>%
          slice(1)
      }) %>%
      select(all_of(names(params)))
  })
}
60+
61+
62+
# chosen_hypervals <- inner_tune(ncv_dat = ncv_dat_list[[1]], mod_FUN_list = mod_FUN_list_ranger, params_list = params_list, error_FUN = error_FUN)
63+

performance-experiment/main.R

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
2+
3+
4+
5+
6+
source("performance-experiment/mlbench-data.R")
7+
source("performance-experiment/create-ncv.R")
8+
source("performance-experiment/create-models.R")
9+
source("performance-experiment/create-grids.R")
10+
source("performance-experiment/inner-tune.R")
11+
source("performance-experiment/outer-cv.R")
12+
source("performance-experiment/ncv-compare.R")
13+
14+
# options(error = function() {
15+
# library(RPushbullet)
16+
# pbPost("note", "Error", geterrmessage())
17+
# if(!interactive()) stop(geterrmessage())
18+
# })
19+
#
20+
#
21+
# library(tictoc)
22+
# tic()
23+
#
24+
#
25+
# pacman::p_load(RPushbullet, glue)
26+
27+
# Seed the session RNG. Note that mlbench_data() re-seeds the RNG itself on
# every call, so the resampling below starts from the state it leaves behind.
set.seed(2019)

# Run the furrr calls in parallel background R sessions. `multiprocess` is
# deprecated in the future package; `multisession` is the strategy that works
# on every platform. The call is namespace-qualified so it does not rely on
# `future` having been attached by another package.
future::plan(future::multisession)

method <- "raschka"
# method <- "kj"
algorithms <- list("glmnet", "rf")

# sample_sizes <- c(100, 800, 2000, 5000, 10000)
# repeats <- seq(1:5)

sample_sizes <- 100
repeats <- 1

# method or method list?

# Large simulated data set used to measure out-of-sample error.
large_dat <- mlbench_data(n = 10^5, noise_sd = 1, seed = 2019)

# One simulated training set per sample size.
simdat_list <- purrr::map(sample_sizes, ~mlbench_data(.x))

# Nested-cv resamples for every (training set, repeats) combination.
ncv_dat_list <- create_ncv(dat = simdat_list, repeats = repeats, method = method)
48+
49+
50+
# Error metric used throughout the experiment: mean absolute error.
# Both arguments are flattened with unlist() before being compared, so
# one-column data frames and plain vectors are handled the same way.
error_FUN <- function(y_obs, y_hat) {
  observed <- unlist(y_obs)
  predicted <- unlist(y_hat)
  Metrics::mae(observed, predicted)
}
55+
56+
mod_FUN_list <- create_models(algorithms)

params_list <- create_grids(algorithms, size = 100)

# create_ncv() returns one nested-cv object per row of
# tidyr::crossing(dat, repeats), i.e. each training set appears once per
# distinct repeats value, with the training set varying slowest. Repeat the
# training sets accordingly so both lists line up element by element;
# pairing `simdat_list` directly only works while there is a single
# repeats value or a single sample size.
sim_dat_per_run <- rep(simdat_list, each = length(unique(repeats)))

ncv_results <- purrr::map2_dfr(ncv_dat_list, sim_dat_per_run, function(ncv_dat, sim_dat) {

  best_hypervals_list <- inner_tune(
    ncv_dat = ncv_dat,
    mod_FUN_list = mod_FUN_list,
    params_list = params_list,
    error_FUN = error_FUN)

  # model, mean, median, sd error, and parameter columns
  # The raschka method additionally hands outer_cv() the training data and
  # the hyperparameter grids.
  if (method == "raschka") {
    cv_stats <- outer_cv(
      ncv_dat = ncv_dat,
      best_hypervals_list = best_hypervals_list,
      mod_FUN_list = mod_FUN_list,
      error_FUN = error_FUN,
      method = method,
      train_dat = sim_dat,
      params_list = params_list)
  } else if (method == "kj") {
    cv_stats <- outer_cv(
      ncv_dat = ncv_dat,
      best_hypervals_list = best_hypervals_list,
      mod_FUN_list = mod_FUN_list,
      error_FUN = error_FUN,
      method = method)
  }

  # Refit the chosen model on the whole training set and compare its error
  # on the large data set with the nested-cv estimate.
  ncv_compare(train_dat = sim_dat,
              large_dat = large_dat,
              cv_stats = cv_stats,
              mod_FUN_list = mod_FUN_list,
              params_list = params_list,
              error_FUN = error_FUN,
              method = method)

})

# Label each result row with its sample size and repeats value.
# NOTE(review): crossing() sorts and de-duplicates `sample_sizes`, so this
# lines up with the results only when `sample_sizes` is already unique and
# ascending - confirm before using other configurations.
indices <- tidyr::crossing(sample_sizes, repeats)

perf_exp_results <- indices %>%
  bind_cols(ncv_results)
101+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# create simulation data
2+
3+
# Inputs are 10 independent variables uniformly distributed on the interval [0,1]; only 5 of these 10 are actually used. Outputs are created according to the formula:
4+
# y = 10 sin(π x1 x2) + 20 (x3 - 0.5)^2 + 10 x4 + 5 x5 + e
5+
6+
# Simulate a Friedman 1 benchmark data set as a data frame.
#
# Args:
#   n:        number of observations to generate.
#   noise_sd: standard deviation passed to mlbench.friedman1() as `sd`.
#   seed:     value handed to set.seed() before simulating; this resets the
#             session RNG on every call.
#
# Returns:
#   A data frame holding the simulated inputs followed by the response,
#   which is the last column and is named "y".
mlbench_data <- function(n, noise_sd = 1, seed = 2019) {
  set.seed(seed)
  sim <- mlbench::mlbench.friedman1(n, sd = noise_sd)
  sim_df <- as.data.frame(cbind(sim$x, sim$y))
  response_col <- ncol(sim_df)
  names(sim_df)[response_col] <- "y"
  sim_df
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# ncv_compare function
2+
3+
4+
# Chooses the best algorithm, fits best model on entire training set, predicts against large simulated data set
5+
6+
7+
# Chooses the best algorithm, refits it on the entire training set, predicts
# on the large simulated data set, and compares that out-of-sample error with
# the nested-cv error estimate.
#
# Args:
#   train_dat:    training data frame handed to the chosen model function.
#   large_dat:    data frame used for the out-of-sample error; the response
#                 is taken from its last column.
#   cv_stats:     outer-loop results. For "kj", a named list of data frames
#                 (one per algorithm) with a `mean_error` column; for
#                 "raschka", a data frame with `model` and `mean_error`
#                 columns. Both also carry the hyperparameter columns.
#   mod_FUN_list: named list of model functions taking (params, data).
#   params_list:  named list of hyperparameter grids; only the column names
#                 are used here.
#   error_FUN:    function(y_obs, y_hat) returning a single numeric error.
#   method:       "kj" or "raschka".
#
# Returns:
#   A data frame with `method`, `oos_error`, `ncv_error`, `delta_error`,
#   `chosen_algorithm` and the chosen hyperparameter values.
#
# Raises:
#   An error if `method` is not "kj" or "raschka", or if no algorithm can be
#   selected for "kj".
ncv_compare <- function(train_dat, large_dat, cv_stats, mod_FUN_list, params_list, error_FUN, method) {

  # Without this guard an unknown method skipped both branches and surfaced
  # later as a confusing "could not find function" error.
  if (!(is.character(method) && length(method) == 1L && method %in% c("kj", "raschka"))) {
    stop("Need to specify method as kj or raschka", call. = FALSE)
  }

  if (method == "kj") {
    # Choose alg with lowest avg error
    chosen_alg <- cv_stats %>%
      bind_rows(.id = "model") %>%
      filter(mean_error == min(mean_error)) %>%
      pull(model)
    if (length(chosen_alg) == 0L) {
      stop("No algorithm could be chosen: no usable mean_error values", call. = FALSE)
    }
    # On a tie keep the first algorithm so the `[[` lookups below receive a
    # single name.
    chosen_alg <- chosen_alg[[1]]
    chosen_stats <- cv_stats[[chosen_alg]]
  } else {
    chosen_alg <- cv_stats %>%
      pull(model)
    chosen_stats <- cv_stats %>%
      filter(model == chosen_alg)
  }

  # Set inputs to chosen alg
  mod_FUN <- mod_FUN_list[[chosen_alg]]
  params <- chosen_stats %>%
    select(all_of(names(params_list[[chosen_alg]])))

  fit <- mod_FUN(params, train_dat)

  preds <- predict(fit, large_dat)
  # Predictions that are not a data frame are expected to carry them in a
  # `predictions` element.
  if (!is.data.frame(preds)) {
    preds <- preds$predictions
  }

  # calculate out-of-sample and retrieve nested-cv error
  y_col <- ncol(large_dat)
  y_obs <- large_dat[y_col]
  oos_error <- round(error_FUN(y_obs, preds), 5)
  ncv_error <- round(pull(chosen_stats, mean_error), 5)

  # delta (the difference between errors) is how well the ncv estimated
  # generalization performance
  bind_cols(oos_error = oos_error, ncv_error = ncv_error) %>%
    mutate(method = method,
           delta_error = abs(oos_error - ncv_error),
           chosen_algorithm = chosen_alg) %>%
    bind_cols(params) %>%
    select(method, everything())
}
63+
64+
65+
66+
67+

0 commit comments

Comments
 (0)