|
| 1 | +# inner loop tuning function |
| 2 | + |
| 3 | + |
| 4 | + |
| 5 | +pacman::p_load(dplyr, furrr, data.table, dtplyr) |
| 6 | + |
#' Tune hyperparameters on the inner resamples of a nested resampling object.
#'
#' For every algorithm (a model function paired with its hyperparameter grid),
#' each grid row is fit on the analysis set and scored on the assessment set
#' of every inner split. Errors are summarised per grid row within each element
#' of `ncv_dat$inner_resamples`, and the row with the lowest mean error (ties
#' broken by the lowest standard deviation) is kept for that element.
#'
#' @param ncv_dat Object with an `inner_resamples` element; each entry of it
#'   must carry a `splits` element of rsample splits (presumably the output of
#'   `rsample::nested_cv()` -- confirm against the caller).
#' @param mod_FUN_list List of model functions, each called as
#'   `mod_FUN(params, analysis_data)` and returning an object that `predict()`
#'   accepts.
#' @param params_list List of hyperparameter grids (data frames, one row per
#'   combination), parallel to `mod_FUN_list`.
#' @param error_FUN Function called as `error_FUN(y_obs, pred)` that returns a
#'   single numeric error value. `y_obs` is passed as a one-column data frame.
#' @return A list with one element per algorithm. Each element is a data frame
#'   holding only the hyperparameter columns, with one row (the selected
#'   combination) per entry of `ncv_dat$inner_resamples`.
inner_tune <- function(ncv_dat, mod_FUN_list, params_list, error_FUN) {

  # Fit one hyperparameter combination on a split's analysis set and return
  # its error on the split's assessment set.
  mod_error <- function(params, mod_FUN, dat) {
    assess_dat <- rsample::assessment(dat)
    # The response is taken to be the last column of the split's data.
    y_col <- ncol(dat$data)
    y_obs <- assess_dat[y_col]
    mod <- mod_FUN(params, rsample::analysis(dat))
    pred <- predict(mod, assess_dat)
    # Prediction objects that wrap their values (the example call at the
    # bottom of the file uses ranger models) are unwrapped. Plain vectors and
    # matrices are passed through as-is; `$` would raise an error on them.
    if (!is.data.frame(pred) && !is.atomic(pred)) {
      pred <- pred$predictions
    }
    error_FUN(y_obs, pred)
  }

  # Score every row of the hyperparameter grid on a single split and return
  # the grid with an added `error` column.
  tune_over_params <- function(dat, mod_FUN, params) {
    # seq_len() keeps an empty grid from being iterated as c(1, 0), and
    # drop = FALSE keeps a one-column grid row a data frame.
    params$error <- purrr::map_dbl(seq_len(nrow(params)), function(row) {
      mod_error(params[row, , drop = FALSE], mod_FUN, dat)
    })
    params
  }

  # Run the grid over every split in `dat$splits`, then summarise the error of
  # each hyperparameter combination across those splits.
  summarize_tune_results <- function(dat, mod_FUN, params) {
    param_names <- names(params)
    # NOTE(review): no `.options` seed is supplied here, so models that draw
    # random numbers may not be reproducible across parallel workers.
    furrr::future_map_dfr(dat$splits, tune_over_params, mod_FUN, params,
                          .progress = TRUE) %>%
      # all_of() marks `param_names` as an external character vector rather
      # than a column name.
      lazy_dt(key_by = all_of(param_names)) %>%
      group_by_at(vars(all_of(param_names))) %>%
      summarize(mean_error = mean(error, na.rm = TRUE),
                sd_error = sd(error, na.rm = TRUE),
                n = length(error)) %>%
      as_tibble()
  }

  tune_algorithms <- purrr::map2(
    mod_FUN_list, params_list,
    function(mod_FUN, params) {
      tuning_results <- purrr::map(ncv_dat$inner_resamples,
                                   summarize_tune_results, mod_FUN, params)

      # Pick the best combination for each entry of inner_resamples: lowest
      # mean error, ties broken by the smallest standard deviation.
      # na.rm = TRUE stops one NaN mean from discarding every row.
      purrr::map_df(tuning_results, function(dat) {
        dat %>%
          filter(mean_error == min(mean_error, na.rm = TRUE)) %>%
          arrange(sd_error) %>%
          slice(1)
      }) %>%
        select(all_of(names(params)))
    }
  )

  # Return the value explicitly; an assignment as the last expression would
  # return it invisibly.
  tune_algorithms
}
| 60 | + |
| 61 | + |
| 62 | +# chosen_hypervals <- inner_tune(ncv_dat = ncv_dat_list[[1]], mod_FUN_list = mod_FUN_list_ranger, params_list = params_list, error_FUN = error_FUN) |
| 63 | + |
0 commit comments