@@ -19,10 +19,101 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
1919}
2020
2121# ------------------------------------------------------------------------------
22+ # min_grid generic - put here so that the generic shows up first in the man file
23+
#' Determine the minimum set of model fits
#'
#' `min_grid()` determines exactly what models should be fit in order to
#' evaluate the entire set of tuning parameter combinations. This is for
#' internal use only.
#' @param x A model specification.
#' @param grid A tibble with tuning parameter combinations.
#' @param ... Not currently used.
#' @return A tibble with the minimum tuning parameters to fit and an additional
#' list column with the parameter combinations used for prediction.
#' @keywords internal
#' @export
min_grid <- function(x, grid, ...) {
  # x is a `model_spec` object from parsnip
  # grid is a tibble of tuning parameter values with names
  # matching the parameter names.
  #
  # The generic name given to UseMethod() must match this function's name
  # exactly (no padding), otherwise no `min_grid.<class>` method is found.
  UseMethod("min_grid")
}
42+
# As an example, if we fit a boosted tree model and tune over
# trees = 1:20 and min_n = c(20, 30)
# we should only have to fit two models:
#
#   trees = 20 & min_n = 20
#   trees = 20 & min_n = 30
#
# The logic related to how this "mini grid" gets made is model-specific.
#
# To get the full set of predictions, we need to know, for each of these two
# models, what values of trees to give to the multi_predict() function.
#
# The current idea is to have a list column of the extra models for prediction.
# For the example above:
#
#   # A tibble: 2 x 3
#     trees min_n .submodels
#     <dbl> <dbl> <list>
#   1    20    20 <named list [1]>
#   2    20    30 <named list [1]>
#
# and the .submodels would both be
#
#   list(trees = 1:19)
#
# There are a lot of other things to consider in future versions like grids
# where there are multiple columns with the same name (maybe the results of
# a recipe) and so on.
71+
72+ # ------------------------------------------------------------------------------
73+ # helper functions
74+
# Template for model results that do not have the sub-model feature.
#
# Returns `grid` with an added `.submodels` list column that holds one empty
# list per row.
blank_submodels <- function(grid) {
  # seq_len() rather than 1:nrow(grid): for a zero-row grid `1:0` is c(1, 0),
  # which would produce two elements for a column that must have none.
  grid %>%
    dplyr::mutate(.submodels = map(seq_len(nrow(grid)), ~ list()))
}
80+
# Names of the parameters that do not have the sub-model feature.
#
# `info` must provide a `name` element and a logical `has_submodel` element of
# the same length. The entries of `name` whose `has_submodel` is FALSE are
# returned (visibly, so the result also prints when called interactively).
get_fixed_args <- function(info) {
  # Get non-sub-model columns to iterate over
  info$name[!info$has_submodel]
}
85+
# Build a table with one row per column of `grid`, giving the column `name`
# and a logical `has_submodel` flag.
#
# `spec` is the model specification: its first class and its `engine` element
# select the rows of the argument table fetched with get_from_env().
# `grid` is the tuning grid; only its column names are used.
get_submodel_info <- function(spec, grid) {
  # The lookup key is "<first class of spec>_args"; the suffix must not carry
  # any padding or the key will not match.
  param_info <-
    get_from_env(paste0(class(spec)[1], "_args")) %>%
    dplyr::filter(engine == spec$engine) %>%
    dplyr::select(name = parsnip, has_submodel)

  # In case a recipe or other activity has grid parameter columns,
  # add those to the results
  grid_names <- names(grid)
  is_mod_param <- grid_names %in% param_info$name
  if (any(!is_mod_param)) {
    param_info <-
      param_info %>%
      dplyr::bind_rows(
        tibble::tibble(
          name = grid_names[!is_mod_param],
          has_submodel = FALSE
        )
      )
  }

  # Keep only the parameters that actually appear in the grid
  param_info %>% dplyr::filter(name %in% grid_names)
}
106+
107+
108+ # ------------------------------------------------------------------------------
# nocov start

# Register the column names used through non-standard evaluation elsewhere in
# the package so that R CMD check does not report them as undefined globals.
# Each entry must be the bare symbol name with no surrounding whitespace.
#' @importFrom utils globalVariables
utils::globalVariables(
  c(
    ".", ".label", ".pred", ".row", "data", "engine", "engine2", "group",
    "lab", "original", "predicted_label", "prediction", "value", "type",
    "neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
    "max_terms", "max_tree", "name", "num_terms", "penalty", "trees"
  )
)

# nocov end
0 commit comments