Skip to content

Commit 6d0a5e7

Browse files
authored
Merge pull request #196 from tidymodels/cran-0-0-3-changes
changes for CRAN 0.0.3 release
2 parents fe92ae9 + a8caca7 commit 6d0a5e7

File tree

91 files changed

+4448
-1013
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+4448
-1013
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: parsnip
2-
Version: 0.0.2.9000
2+
Version: 0.0.3
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
55
Authors@R: c(

NAMESPACE

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ S3method(fit_xy,model_spec)
55
S3method(has_multi_predict,default)
66
S3method(has_multi_predict,model_fit)
77
S3method(has_multi_predict,workflow)
8+
S3method(min_grid,boost_tree)
9+
S3method(min_grid,linear_reg)
10+
S3method(min_grid,logistic_reg)
11+
S3method(min_grid,mars)
12+
S3method(min_grid,multinom_reg)
13+
S3method(min_grid,nearest_neighbor)
814
S3method(multi_predict,"_C5.0")
915
S3method(multi_predict,"_earth")
1016
S3method(multi_predict,"_elnet")
@@ -50,8 +56,11 @@ S3method(print,svm_rbf)
5056
S3method(translate,boost_tree)
5157
S3method(translate,decision_tree)
5258
S3method(translate,default)
59+
S3method(translate,linear_reg)
60+
S3method(translate,logistic_reg)
5361
S3method(translate,mars)
5462
S3method(translate,mlp)
63+
S3method(translate,multinom_reg)
5564
S3method(translate,nearest_neighbor)
5665
S3method(translate,rand_forest)
5766
S3method(translate,surv_reg)
@@ -104,6 +113,13 @@ export(linear_reg)
104113
export(logistic_reg)
105114
export(make_classes)
106115
export(mars)
116+
export(min_grid)
117+
export(min_grid.boost_tree)
118+
export(min_grid.linear_reg)
119+
export(min_grid.logistic_reg)
120+
export(min_grid.mars)
121+
export(min_grid.multinom_reg)
122+
export(min_grid.nearest_neighbor)
107123
export(mlp)
108124
export(model_printer)
109125
export(multi_predict)

NEWS.md

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
1-
# parsnip 0.0.2.9000
1+
# parsnip 0.0.3
2+
3+
Unplanned release based on CRAN requirements for Solaris.
24

35
## Breaking Changes
46

5-
* The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env()` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html).
6-
* The mode need to be declared for models that can be used for more than one mode prior to fitting and/or translation).
7+
* The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html).
8+
9+
* The mode needs to be declared for models that can be used for more than one mode prior to fitting and/or translation.
10+
711
* For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`.
812

13+
* For `glmnet` models, the full regularization path is always fit regardless of the value given to `penalty`. Previously, the model was fit by passing `penalty` to `glmnet`'s `lambda` argument and the model could only make predictions at those specific values. [(#195)](https://github.com/tidymodels/parsnip/issues/195)
14+
915
## New Features
1016

1117
* `add_rowindex()` can add a column called `.row` to a data frame.
1218

1319
* If a computational engine is not explicitly set, a default will be used. Each default is documented on the corresponding model page. A warning is issued at fit time unless verbosity is zero.
14-
* `nearest_neighbor` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized.
20+
21+
* `nearest_neighbor()` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized.
22+
23+
* A suite of internal functions was added to help with upcoming model tuning features.
1524

1625

1726
# parsnip 0.0.2

R/aaa.R

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,102 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
1919
}
2020

2121
# ------------------------------------------------------------------------------
22+
# min_grid generic - put here so that the generic shows up first in the man file
23+
24+
#' Determine the minimum set of model fits
25+
#'
26+
#' `min_grid` determines exactly what models should be fit in order to
27+
#' evaluate the entire set of tuning parameter combinations. This is for
28+
#' internal use only and the API may change in the near future.
29+
#' @param x A model specification.
30+
#' @param grid A tibble with tuning parameter combinations.
31+
#' @param ... Not currently used.
32+
#' @return A tibble with the minimum tuning parameters to fit and an additional
33+
#' list column with the parameter combinations used for prediction.
34+
#' @keywords internal
35+
#' @export
36+
min_grid <- function(x, grid, ...) {
  # S3 generic: the method is selected from the class of `x`, which is
  # expected to be a parsnip `model_spec` object.
  # `grid` is a tibble of tuning parameter values whose column names match
  # the parameter names; it is handed to the selected method unchanged.
  UseMethod("min_grid")
}
42+
43+
# As an example, if we fit a boosted tree model and tune over
44+
# trees = 1:20 and min_n = c(20, 30)
45+
# we should only have to fit two models:
46+
#
47+
# trees = 20 & min_n = 20
48+
# trees = 20 & min_n = 30
49+
#
50+
# The logic related to how this "mini grid" gets made is model-specific.
51+
#
52+
# To get the full set of predictions, we need to know, for each of these two
53+
# models, what values of trees to give to the multi_predict() function.
54+
#
55+
# The current idea is to have a list column of the extra models for prediction.
56+
# For the example above:
57+
#
58+
# # A tibble: 2 x 3
59+
# trees min_n .submodels
60+
# <dbl> <dbl> <list>
61+
# 1 20 20 <named list [1]>
62+
# 2 20 30 <named list [1]>
63+
#
64+
# and the .submodels would both be
65+
#
66+
# list(trees = 1:19)
67+
#
68+
# There are a lot of other things to consider in future versions like grids
69+
# where there are multiple columns with the same name (maybe the results of
70+
# a recipe) and so on.
71+
72+
# ------------------------------------------------------------------------------
73+
# helper functions
74+
75+
# Template for model results that do not have the sub-model feature
76+
blank_submodels <- function(grid) {
  # Returns `grid` with an extra list column `.submodels` holding one empty
  # list per row, i.e. "no additional sub-model predictions for this fit".
  #
  # `seq_len()` is used instead of `1:nrow(grid)` so that a zero-row grid
  # produces a zero-length column rather than a length-2 one (1:0 is c(1, 0)).
  grid %>%
    dplyr::mutate(.submodels = map(seq_len(nrow(grid)), ~ list()))
}
80+
81+
get_fixed_args <- function(info) {
  # `info` is a data frame with a `name` column and a logical `has_submodel`
  # column (as built by get_submodel_info()).
  #
  # Returns the names of the parameters that are NOT flagged as having
  # sub-models; these are the columns the min_grid() methods group by.
  # The value is returned directly (previously it was the result of an
  # assignment to an unused local, and therefore invisible).
  info$name[!info$has_submodel]
}
85+
86+
get_submodel_info <- function(spec, grid) {
  # Fetch the argument table stored under "<model class>_args", keep the rows
  # for the engine set on `spec`, and reduce it to the parsnip argument name
  # (renamed to `name`) plus the `has_submodel` flag.
  args_key <- paste0(class(spec)[1], "_args")
  param_info <- get_from_env(args_key)
  param_info <- dplyr::filter(param_info, engine == spec$engine)
  param_info <- dplyr::select(param_info, name = parsnip, has_submodel)

  # Grid columns that are not model arguments (e.g. parameters coming from a
  # recipe or other activity) are appended and flagged as having no sub-models.
  grid_names <- names(grid)
  extra_names <- grid_names[!(grid_names %in% param_info$name)]
  if (length(extra_names) > 0) {
    extra_info <- tibble::tibble(name = extra_names, has_submodel = FALSE)
    param_info <- dplyr::bind_rows(param_info, extra_info)
  }

  # Only report parameters that actually appear in the grid.
  dplyr::filter(param_info, name %in% grid_names)
}
106+
107+
108+
# ------------------------------------------------------------------------------
109+
# nocov
22110

23111
#' @importFrom utils globalVariables
24112
utils::globalVariables(
25113
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
26114
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
27-
"neighbors")
115+
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
116+
"max_terms", "max_tree", "name", "num_terms", "penalty", "trees",
117+
"sub_neighbors")
28118
)
119+
120+
# nocov end

R/boost_tree.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,41 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {
514514
pred[, c(".row", "trees", nms)]
515515
}
516516

517+
# ------------------------------------------------------------------------------
518+
519+
#' @export
520+
#' @export min_grid.boost_tree
521+
#' @rdname min_grid
522+
min_grid.boost_tree <- function(x, grid, ...) {
  # `x` is a boost_tree model specification and `grid` is a tibble of tuning
  # parameter combinations. The result has the columns of `grid` (for the
  # models that must actually be fit) plus a list column `.submodels`.
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No ability to do submodels? Finish here:
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)
  # NOTE(review): when `trees` is the only grid column, `fixed_args` is empty
  # and the joins below are called with `by = character(0)`; how that is
  # handled depends on the installed dplyr version -- confirm.

  # For boosted trees, fit the model with the most trees (conditional on the
  # other parameters) so that you can do predictions on the smaller models.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(trees = max(trees, na.rm = TRUE)) %>%
    dplyr::ungroup()

  # Add a column .submodels that is a list with what should be predicted
  # by `multi_predict()` (assuming `predict()` has already been executed
  # on the original value of 'trees').
  # The final join takes only `fit_only`: the original call also passed `grid`
  # positionally, which bound it to full_join()'s `copy` argument.
  min_grid_df <-
    fit_only %>%
    dplyr::rename(max_tree = trees) %>%
    dplyr::full_join(grid, by = fixed_args) %>%
    dplyr::filter(trees != max_tree) %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(.submodels = list(list(trees = trees))) %>%
    dplyr::ungroup() %>%
    dplyr::full_join(fit_only, by = fixed_args)

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}
554+

R/linear_reg.R

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@
6868
#'
6969
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
7070
#'
71-
#' When using `glmnet` models, there is the option to pass
72-
#' multiple values (or no values) to the `penalty` argument. This
73-
#' can have an effect on the model object results. When using the
71+
#' For `glmnet` models, the full regularization path is always fit regardless
72+
#' of the value given to `penalty`. Also, there is the option to pass
73+
#' multiple values (or no values) to the `penalty` argument. When using the
7474
#' `predict()` method in these cases, the return value depends on
7575
#' the value of `penalty`. When using `predict()`, only a single
7676
#' value of the penalty can be used. When predicting on multiple
@@ -138,6 +138,23 @@ print.linear_reg <- function(x, ...) {
138138
invisible(x)
139139
}
140140

141+
142+
#' @export
143+
translate.linear_reg <- function(x, engine = x$engine, ...) {
  # Start from the generic translation of the specification.
  x <- translate.default(x, engine, ...)

  # Only the glmnet engine needs extra handling.
  if (engine != "glmnet") {
    return(x)
  }

  # See discussion in https://github.com/tidymodels/parsnip/issues/195
  # Drop `lambda` from the arguments of the fit call.
  x$method$fit$args$lambda <- NULL
  # Since the `fit` information is gone for the penalty, we need to have an
  # evaluated value for the parameter.
  x$args$penalty <- rlang::eval_tidy(x$args$penalty)

  x
}
156+
157+
141158
# ------------------------------------------------------------------------------
142159

143160
#' @inheritParams update.boost_tree
@@ -274,6 +291,11 @@ predict._elnet <-
274291
if (any(names(enquos(...)) == "newdata"))
275292
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
276293

294+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
295+
if (is.null(penalty) & !is.null(object$spec$args$penalty)) {
296+
penalty <- object$spec$args$penalty
297+
}
298+
277299
object$spec$args$penalty <- check_penalty(penalty, object, multi)
278300

279301
object$spec <- eval_args(object$spec)
@@ -314,7 +336,12 @@ multi_predict._elnet <-
314336
object$spec <- eval_args(object$spec)
315337

316338
if (is.null(penalty)) {
317-
penalty <- object$fit$lambda
339+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
340+
if (!is.null(object$spec$args$penalty)) {
341+
penalty <- object$spec$args$penalty
342+
} else {
343+
penalty <- object$fit$lambda
344+
}
318345
}
319346

320347
pred <- predict._elnet(object, new_data = new_data, type = "raw",
@@ -332,3 +359,37 @@ multi_predict._elnet <-
332359
names(pred) <- NULL
333360
tibble(.pred = pred)
334361
}
362+
363+
364+
# ------------------------------------------------------------------------------
365+
366+
#' @export
367+
#' @export min_grid.linear_reg
368+
#' @rdname min_grid
369+
min_grid.linear_reg <- function(x, grid, ...) {
  # `x` is a linear_reg model specification and `grid` is a tibble of tuning
  # parameter combinations. The result has the columns of `grid` (for the
  # models that must actually be fit) plus a list column `.submodels`.
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # Nothing in the grid is flagged as having sub-models: return the grid
  # unchanged apart from an empty `.submodels` entry per row.
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)
  # NOTE(review): when `penalty` is the only grid column, `fixed_args` is
  # empty and the joins below are called with `by = character(0)`; how that is
  # handled depends on the installed dplyr version -- confirm.

  # Within each combination of the other parameters, keep only the largest
  # penalty as the model to fit.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(penalty = max(penalty, na.rm = TRUE)) %>%
    dplyr::ungroup()

  # Collect the remaining penalty values of each group into `.submodels`, then
  # join back onto `fit_only` so groups without extra values are kept.
  # The final join takes only `fit_only`: the original call also passed `grid`
  # positionally, which bound it to full_join()'s `copy` argument.
  min_grid_df <-
    fit_only %>%
    dplyr::rename(max_penalty = penalty) %>%
    dplyr::full_join(grid, by = fixed_args) %>%
    dplyr::filter(penalty != max_penalty) %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(.submodels = list(list(penalty = penalty))) %>%
    dplyr::ungroup() %>%
    dplyr::full_join(fit_only, by = fixed_args)

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}

0 commit comments

Comments
 (0)