Skip to content

Commit 7571ff9

Browse files
committed
code to compute the minimum tuning grid
1 parent 45abcf6 commit 7571ff9

16 files changed

+514
-9
lines changed

NAMESPACE

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ S3method(fit_xy,model_spec)
55
S3method(has_multi_predict,default)
66
S3method(has_multi_predict,model_fit)
77
S3method(has_multi_predict,workflow)
8+
S3method(min_grid,boost_tree)
9+
S3method(min_grid,linear_reg)
10+
S3method(min_grid,logistic_reg)
11+
S3method(min_grid,mars)
12+
S3method(min_grid,multinom_reg)
13+
S3method(min_grid,nearest_neighbor)
814
S3method(multi_predict,"_C5.0")
915
S3method(multi_predict,"_earth")
1016
S3method(multi_predict,"_elnet")
@@ -104,6 +110,13 @@ export(linear_reg)
104110
export(logistic_reg)
105111
export(make_classes)
106112
export(mars)
113+
export(min_grid)
114+
export(min_grid.boost_tree)
115+
export(min_grid.linear_reg)
116+
export(min_grid.logistic_reg)
117+
export(min_grid.mars)
118+
export(min_grid.multinom_reg)
119+
export(min_grid.nearest_neighbor)
107120
export(mlp)
108121
export(model_printer)
109122
export(multi_predict)

R/aaa.R

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,101 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
1919
}
2020

2121
# ------------------------------------------------------------------------------
22+
# min_grid generic - put here so that the generic shows up first in the man file
23+
24+
#' Determine the minimum set of model fits
#'
#' `min_grid()` works out which models actually have to be fit so that every
#' tuning parameter combination in `grid` can still be evaluated. It is an S3
#' generic that dispatches on the class of `x`. This is for internal use only.
#'
#' @param x A model specification (a `model_spec` object from parsnip).
#' @param grid A tibble of tuning parameter values whose column names match
#'   the parameter names.
#' @param ... Not currently used.
#' @return A tibble with the minimum tuning parameters to fit and an additional
#'   list column with the parameter combinations used for prediction.
#' @keywords internal
#' @export
min_grid <- function(x, grid, ...) UseMethod("min_grid")
42+
43+
# As an example, if we fit a boosted tree model and tune over
44+
# trees = 1:20 and min_n = c(20, 30)
45+
# we should only have to fit two models:
46+
#
47+
# trees = 20 & min_n = 20
48+
# trees = 20 & min_n = 30
49+
#
50+
# The logic related to how this "mini grid" gets made is model-specific.
51+
#
52+
# To get the full set of predictions, we need to know, for each of these two
53+
# models, what values of trees to give to the multi_predict() function.
54+
#
55+
# The current idea is to have a list column of the extra models for prediction.
56+
# For the example above:
57+
#
58+
# # A tibble: 2 x 3
59+
# trees min_n .submodels
60+
# <dbl> <dbl> <list>
61+
# 1 20 20 <named list [1]>
62+
# 2 20 30 <named list [1]>
63+
#
64+
# and the .submodels would both be
65+
#
66+
# list(trees = 1:19)
67+
#
68+
# There are a lot of other things to consider in future versions like grids
69+
# where there are multiple columns with the same name (maybe the results of
70+
# a recipe) and so on.
71+
72+
# ------------------------------------------------------------------------------
73+
# helper functions
74+
75+
# Template for model results that do not have the sub-model feature
76+
# Add an empty `.submodels` entry to every row of `grid`.
#
# `grid` is a data frame (normally a tibble) of tuning parameter combinations.
# The result is `grid` with one extra list column, `.submodels`, holding an
# empty list for each row.
blank_submodels <- function(grid) {
  # `rep()` on `nrow(grid)` gives a zero-length list for an empty grid, whereas
  # `1:nrow(grid)` would give `c(1, 0)` and a column of the wrong length.
  grid$.submodels <- rep(list(list()), nrow(grid))
  grid
}
80+
81+
# Names of the parameters that do not have the sub-model feature.
#
# `info` is a data frame with a `name` column and a logical `has_submodel`
# column. The result is the subset of `info$name` where `has_submodel` is
# FALSE; these are the columns the `min_grid()` methods group by.
get_fixed_args <- function(info) {
  # Return the value directly: ending the body with an assignment would hand
  # the result back invisibly and leave an unused local behind.
  info$name[!info$has_submodel]
}
85+
86+
# Look up, for each column of `grid`, whether it has the sub-model feature.
#
# The parameter table is fetched with `get_from_env()` under the key
# "<first class of spec>_args", restricted to the rows for `spec$engine`, and
# reduced to the columns `name` (taken from `parsnip`) and `has_submodel`.
# The result only contains rows whose `name` is a column of `grid`.
get_submodel_info <- function(spec, grid) {
  args_key <- paste0(class(spec)[1], "_args")
  engine_args <- dplyr::filter(get_from_env(args_key), engine == spec$engine)
  param_info <- dplyr::select(engine_args, name = parsnip, has_submodel)

  # Grid columns that are not model parameters (e.g. from a recipe or other
  # activity) are appended and flagged as having no sub-models.
  grid_names <- names(grid)
  extra_names <- grid_names[!(grid_names %in% param_info$name)]
  if (length(extra_names) > 0) {
    extras <- tibble::tibble(name = extra_names, has_submodel = FALSE)
    param_info <- dplyr::bind_rows(param_info, extras)
  }

  dplyr::filter(param_info, name %in% grid_names)
}
106+
107+
108+
# ------------------------------------------------------------------------------
109+
# nocov
22110

23111
#' @importFrom utils globalVariables
24112
utils::globalVariables(
25113
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
26114
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
27-
"neighbors")
115+
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
116+
"max_terms", "max_tree", "name", "num_terms", "penalty", "trees")
28117
)
118+
119+
# nocov end

R/boost_tree.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,41 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {
514514
pred[, c(".row", "trees", nms)]
515515
}
516516

517+
# ------------------------------------------------------------------------------
518+
519+
#' @export
#' @export min_grid.boost_tree
#' @rdname min_grid
min_grid.boost_tree <- function(x, grid, ...) {
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No ability to do submodels? Finish here:
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # For boosted trees, fit the model with the most trees (conditional on the
  # other parameters) so that you can do predictions on the smaller models.
  # With no fixed parameters this is a single ungrouped row.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(trees = max(trees, na.rm = TRUE)) %>%
    dplyr::ungroup()

  if (length(fixed_args) == 0) {
    # `trees` is the only column being tuned, so there is nothing to join on.
    # Depending on the dplyr version, an empty `by` is an error or a
    # deprecated cross join, so build the single row directly instead.
    smaller <- grid$trees[which(grid$trees != fit_only$trees)]
    fit_only$.submodels <-
      list(if (length(smaller) > 0) list(trees = smaller) else list())
    min_grid_df <- fit_only
  } else {
    # Add a column .submodels that is a list with what should be predicted
    # by `multi_predict()` (assuming `predict()` has already been executed
    # on the original value of 'trees').
    # The last join takes `fit_only` alone: the piped table is already the
    # first argument, so a further positional table would land in `copy`.
    min_grid_df <-
      fit_only %>%
      dplyr::rename(max_tree = trees) %>%
      dplyr::full_join(grid, by = fixed_args) %>%
      dplyr::filter(trees != max_tree) %>%
      dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
      dplyr::summarize(.submodels = list(list(trees = trees))) %>%
      dplyr::ungroup() %>%
      dplyr::full_join(fit_only, by = fixed_args)

    # Groups whose only value is the maximum come back from the join with a
    # NULL entry; use an empty list instead, as `blank_submodels()` does.
    no_submodels <- vapply(min_grid_df$.submodels, is.null, logical(1))
    min_grid_df$.submodels[no_submodels] <- list(list())
  }

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}
554+

R/linear_reg.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,37 @@ multi_predict._elnet <-
332332
names(pred) <- NULL
333333
tibble(.pred = pred)
334334
}
335+
336+
337+
# ------------------------------------------------------------------------------
338+
339+
#' @export
#' @export min_grid.linear_reg
#' @rdname min_grid
min_grid.linear_reg <- function(x, grid, ...) {
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No ability to do submodels? Finish here:
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # Only the largest `penalty` within each combination of the other
  # parameters is fit. With no fixed parameters this is one ungrouped row.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(penalty = max(penalty, na.rm = TRUE)) %>%
    dplyr::ungroup()

  if (length(fixed_args) == 0) {
    # `penalty` is the only column being tuned, so there is nothing to join
    # on. Depending on the dplyr version, an empty `by` is an error or a
    # deprecated cross join, so build the single row directly instead.
    smaller <- grid$penalty[which(grid$penalty != fit_only$penalty)]
    fit_only$.submodels <-
      list(if (length(smaller) > 0) list(penalty = smaller) else list())
    min_grid_df <- fit_only
  } else {
    # `.submodels` lists the remaining `penalty` values of each group.
    # The last join takes `fit_only` alone: the piped table is already the
    # first argument, so a further positional table would land in `copy`.
    min_grid_df <-
      fit_only %>%
      dplyr::rename(max_penalty = penalty) %>%
      dplyr::full_join(grid, by = fixed_args) %>%
      dplyr::filter(penalty != max_penalty) %>%
      dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
      dplyr::summarize(.submodels = list(list(penalty = penalty))) %>%
      dplyr::ungroup() %>%
      dplyr::full_join(fit_only, by = fixed_args)

    # Groups whose only value is the maximum come back from the join with a
    # NULL entry; use an empty list instead, as `blank_submodels()` does.
    no_submodels <- vapply(min_grid_df$.submodels, is.null, logical(1))
    min_grid_df$.submodels[no_submodels] <- list(list())
  }

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}

R/logistic_reg.R

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ organize_glmnet_prob <- function(x, object) {
262262
# ------------------------------------------------------------------------------
263263

264264
#' @export
265-
predict._lognet <- function (object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
265+
predict._lognet <- function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
266266
if (any(names(enquos(...)) == "newdata"))
267267
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
268268

@@ -330,7 +330,7 @@ multi_predict._lognet <-
330330

331331

332332
#' @export
333-
predict_class._lognet <- function (object, new_data, ...) {
333+
predict_class._lognet <- function(object, new_data, ...) {
334334
if (any(names(enquos(...)) == "newdata"))
335335
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
336336

@@ -339,7 +339,7 @@ predict_class._lognet <- function (object, new_data, ...) {
339339
}
340340

341341
#' @export
342-
predict_classprob._lognet <- function (object, new_data, ...) {
342+
predict_classprob._lognet <- function(object, new_data, ...) {
343343
if (any(names(enquos(...)) == "newdata"))
344344
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
345345

@@ -348,11 +348,18 @@ predict_classprob._lognet <- function (object, new_data, ...) {
348348
}
349349

350350
#' @export
351-
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
351+
# Raw predictions for `_lognet` fits: reject a misspelled `newdata` argument,
# replace the stored spec with the result of `eval_args()`, then delegate to
# `predict_raw.model_fit()`.
predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
  dot_names <- names(enquos(...))
  if (any(dot_names == "newdata")) {
    stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
  }

  object$spec <- eval_args(object$spec)
  predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}
358358

359+
360+
# ------------------------------------------------------------------------------
361+
362+
#' @export
363+
#' @export min_grid.logistic_reg
364+
#' @rdname min_grid
365+
min_grid.logistic_reg <- min_grid.linear_reg

R/mars.R

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,36 @@ earth_by_terms <- function(num_terms, object, new_data, type, ...) {
262262
pred[[".row"]] <- 1:nrow(new_data)
263263
pred[, c(".row", "num_terms", nms)]
264264
}
265+
266+
# ------------------------------------------------------------------------------
267+
268+
#' @export
#' @export min_grid.mars
#' @rdname min_grid
min_grid.mars <- function(x, grid, ...) {
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No ability to do submodels? Finish here:
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # Only the largest `num_terms` within each combination of the other
  # parameters is fit. With no fixed parameters this is one ungrouped row.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(num_terms = max(num_terms, na.rm = TRUE)) %>%
    dplyr::ungroup()

  if (length(fixed_args) == 0) {
    # `num_terms` is the only column being tuned, so there is nothing to join
    # on. Depending on the dplyr version, an empty `by` is an error or a
    # deprecated cross join, so build the single row directly instead.
    smaller <- grid$num_terms[which(grid$num_terms != fit_only$num_terms)]
    fit_only$.submodels <-
      list(if (length(smaller) > 0) list(num_terms = smaller) else list())
    min_grid_df <- fit_only
  } else {
    # `.submodels` lists the remaining `num_terms` values of each group.
    # The last join takes `fit_only` alone: the piped table is already the
    # first argument, so a further positional table would land in `copy`.
    min_grid_df <-
      fit_only %>%
      dplyr::rename(max_terms = num_terms) %>%
      dplyr::full_join(grid, by = fixed_args) %>%
      dplyr::filter(num_terms != max_terms) %>%
      dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
      dplyr::summarize(.submodels = list(list(num_terms = num_terms))) %>%
      dplyr::ungroup() %>%
      dplyr::full_join(fit_only, by = fixed_args)

    # Groups whose only value is the maximum come back from the join with a
    # NULL entry; use an empty list instead, as `blank_submodels()` does.
    no_submodels <- vapply(min_grid_df$.submodels, is.null, logical(1))
    min_grid_df$.submodels[no_submodels] <- list(list())
  }

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}

R/multinom_reg.R

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ multi_predict._multnet <-
249249
if (is.null(type))
250250
type <- "class"
251251
if (!(type %in% c("class", "prob", "link", "raw"))) {
252-
stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
252+
stop("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
253253
}
254254
if (type == "prob")
255255
dots$type <- "response"
@@ -290,19 +290,19 @@ multi_predict._multnet <-
290290
}
291291

292292
#' @export
293-
predict_class._multnet <- function (object, new_data, ...) {
293+
# Class predictions for `_multnet` fits: replace the stored spec with the
# result of `eval_args()`, then delegate to `predict_class.model_fit()`.
predict_class._multnet <- function(object, new_data, ...) {
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec
  predict_class.model_fit(object, new_data = new_data, ...)
}
297297

298298
#' @export
299-
predict_classprob._multnet <- function (object, new_data, ...) {
299+
# Class probability predictions for `_multnet` fits: replace the stored spec
# with the result of `eval_args()`, then delegate to
# `predict_classprob.model_fit()`.
predict_classprob._multnet <- function(object, new_data, ...) {
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec
  predict_classprob.model_fit(object, new_data = new_data, ...)
}
303303

304304
#' @export
305-
predict_raw._multnet <- function (object, new_data, opts = list(), ...) {
305+
# Raw predictions for `_multnet` fits: replace the stored spec with the
# result of `eval_args()`, then delegate to `predict_raw.model_fit()`,
# forwarding `opts` unchanged.
predict_raw._multnet <- function(object, new_data, opts = list(), ...) {
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec
  predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}
@@ -323,3 +323,10 @@ check_glmnet_lambda <- function(dat, object) {
323323
dat
324324
}
325325

326+
327+
# ------------------------------------------------------------------------------
328+
329+
#' @export
330+
#' @export min_grid.multinom_reg
331+
#' @rdname min_grid
332+
min_grid.multinom_reg <- min_grid.linear_reg

R/nearest_neighbor.R

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,36 @@ knn_by_k <- function(k, object, new_data, type, ...) {
218218
dplyr::mutate(neighbors = k, .row = dplyr::row_number()) %>%
219219
dplyr::select(.row, neighbors, dplyr::starts_with(".pred"))
220220
}
221+
222+
# ------------------------------------------------------------------------------
223+
224+
#' @export
#' @export min_grid.nearest_neighbor
#' @rdname min_grid
min_grid.nearest_neighbor <- function(x, grid, ...) {
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No ability to do submodels? Finish here:
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # Only the largest `neighbors` within each combination of the other
  # parameters is fit. With no fixed parameters this is one ungrouped row.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(neighbors = max(neighbors, na.rm = TRUE)) %>%
    dplyr::ungroup()

  if (length(fixed_args) == 0) {
    # `neighbors` is the only column being tuned, so there is nothing to join
    # on. Depending on the dplyr version, an empty `by` is an error or a
    # deprecated cross join, so build the single row directly instead.
    smaller <- grid$neighbors[which(grid$neighbors != fit_only$neighbors)]
    fit_only$.submodels <-
      list(if (length(smaller) > 0) list(neighbors = smaller) else list())
    min_grid_df <- fit_only
  } else {
    # `.submodels` lists the remaining `neighbors` values of each group.
    # The last join takes `fit_only` alone: the piped table is already the
    # first argument, so a further positional table would land in `copy`.
    min_grid_df <-
      fit_only %>%
      dplyr::rename(max_neighbor = neighbors) %>%
      dplyr::full_join(grid, by = fixed_args) %>%
      dplyr::filter(neighbors != max_neighbor) %>%
      dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
      dplyr::summarize(.submodels = list(list(neighbors = neighbors))) %>%
      dplyr::ungroup() %>%
      dplyr::full_join(fit_only, by = fixed_args)

    # Groups whose only value is the maximum come back from the join with a
    # NULL entry; use an empty list instead, as `blank_submodels()` does.
    no_submodels <- vapply(min_grid_df$.submodels, is.null, logical(1))
    min_grid_df$.submodels[no_submodels] <- list(list())
  }

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}

0 commit comments

Comments
 (0)