Skip to content

Commit 7b36906

Browse files
committed
Merge branch 'master' into rlang-tibble-updates
2 parents 42ca158 + 6d0a5e7 commit 7b36906

File tree

110 files changed

+4918
-1079
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

110 files changed

+4918
-1079
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: parsnip
2-
Version: 0.0.2.9000
2+
Version: 0.0.3
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
55
Authors@R: c(

NAMESPACE

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,26 @@
22

33
S3method(fit,model_spec)
44
S3method(fit_xy,model_spec)
5+
S3method(has_multi_predict,default)
6+
S3method(has_multi_predict,model_fit)
7+
S3method(has_multi_predict,workflow)
8+
S3method(min_grid,boost_tree)
9+
S3method(min_grid,linear_reg)
10+
S3method(min_grid,logistic_reg)
11+
S3method(min_grid,mars)
12+
S3method(min_grid,multinom_reg)
13+
S3method(min_grid,nearest_neighbor)
514
S3method(multi_predict,"_C5.0")
615
S3method(multi_predict,"_earth")
716
S3method(multi_predict,"_elnet")
817
S3method(multi_predict,"_lognet")
918
S3method(multi_predict,"_multnet")
19+
S3method(multi_predict,"_train.kknn")
1020
S3method(multi_predict,"_xgb.Booster")
1121
S3method(multi_predict,default)
22+
S3method(multi_predict_args,default)
23+
S3method(multi_predict_args,model_fit)
24+
S3method(multi_predict_args,workflow)
1225
S3method(nullmodel,default)
1326
S3method(predict,"_elnet")
1427
S3method(predict,"_lognet")
@@ -43,8 +56,11 @@ S3method(print,svm_rbf)
4356
S3method(translate,boost_tree)
4457
S3method(translate,decision_tree)
4558
S3method(translate,default)
59+
S3method(translate,linear_reg)
60+
S3method(translate,logistic_reg)
4661
S3method(translate,mars)
4762
S3method(translate,mlp)
63+
S3method(translate,multinom_reg)
4864
S3method(translate,nearest_neighbor)
4965
S3method(translate,rand_forest)
5066
S3method(translate,surv_reg)
@@ -91,14 +107,23 @@ export(get_fit)
91107
export(get_from_env)
92108
export(get_model_env)
93109
export(get_pred_type)
110+
export(has_multi_predict)
94111
export(keras_mlp)
95112
export(linear_reg)
96113
export(logistic_reg)
97114
export(make_classes)
98115
export(mars)
116+
export(min_grid)
117+
export(min_grid.boost_tree)
118+
export(min_grid.linear_reg)
119+
export(min_grid.logistic_reg)
120+
export(min_grid.mars)
121+
export(min_grid.multinom_reg)
122+
export(min_grid.nearest_neighbor)
99123
export(mlp)
100124
export(model_printer)
101125
export(multi_predict)
126+
export(multi_predict_args)
102127
export(multinom_reg)
103128
export(nearest_neighbor)
104129
export(null_model)
@@ -210,4 +235,5 @@ importFrom(utils,capture.output)
210235
importFrom(utils,getFromNamespace)
211236
importFrom(utils,globalVariables)
212237
importFrom(utils,head)
238+
importFrom(utils,methods)
213239
importFrom(vctrs,vec_unique)

NEWS.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
1-
# parsnip 0.0.2.9000
1+
# parsnip 0.0.3
2+
3+
Unplanned release based on CRAN requirements for Solaris.
24

35
## Breaking Changes
46

5-
* The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env()` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html).
6-
* The mode need to be declared for models that can be used for more than one mode prior to fitting and/or translation).
7+
* The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html).
8+
9+
* The mode needs to be declared for models that can be used for more than one mode prior to fitting and/or translation.
10+
711
* For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`.
812

13+
* For `glmnet` models, the full regularization path is always fit regardless of the value given to `penalty`. Previously, the model was fit by passing `penalty` to `glmnet`'s `lambda` argument and the model could only make predictions at those specific values. [(#195)](https://github.com/tidymodels/parsnip/issues/195)
14+
915
## New Features
1016

1117
* `add_rowindex()` can add a column called `.row` to a data frame.
1218

1319
* If a computational engine is not explicitly set, a default will be used. Each default is documented on the corresponding model page. A warning is issued at fit time unless verbosity is zero.
1420

21+
* `nearest_neighbor()` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized.
22+
23+
* A suite of internal functions were added to help with upcoming model tuning features.
24+
1525

1626
# parsnip 0.0.2
1727

R/aaa.R

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,102 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
2121
}
2222

2323
# ------------------------------------------------------------------------------
24+
# min_grid generic - put here so that the generic shows up first in the man file
25+
26+
#' Determine the minimum set of model fits
27+
#'
28+
#' `min_grid` determines exactly what models should be fit in order to
29+
#' evaluate the entire set of tuning parameter combinations. This is for
30+
#' internal use only and the API may change in the near future.
31+
#' @param x A model specification.
32+
#' @param grid A tibble with tuning parameter combinations.
33+
#' @param ... Not currently used.
34+
#' @return A tibble with the minimum tuning parameters to fit and an additional
35+
#' list column with the parameter combinations used for prediction.
36+
#' @keywords internal
37+
#' @export
38+
min_grid <- function(x, grid, ...) {
  # Dispatches on the class of `x`, a parsnip `model_spec` object.
  # `grid` is a tibble of tuning parameter values whose column names
  # match the parameter names.
  UseMethod("min_grid")
}
44+
45+
# As an example, if we fit a boosted tree model and tune over
46+
# trees = 1:20 and min_n = c(20, 30)
47+
# we should only have to fit two models:
48+
#
49+
# trees = 20 & min_n = 20
50+
# trees = 20 & min_n = 30
51+
#
52+
# The logic related to how this "mini grid" gets made is model-specific.
53+
#
54+
# To get the full set of predictions, we need to know, for each of these two
55+
# models, what values of trees to give to the multi_predict() function.
56+
#
57+
# The current idea is to have a list column of the extra models for prediction.
58+
# For the example above:
59+
#
60+
# # A tibble: 2 x 3
61+
# trees min_n .submodels
62+
# <dbl> <dbl> <list>
63+
# 1 20 20 <named list [1]>
64+
# 2 20 30 <named list [1]>
65+
#
66+
# and the .submodels would both be
67+
#
68+
# list(trees = 1:19)
69+
#
70+
# There are a lot of other things to consider in future versions like grids
71+
# where there are multiple columns with the same name (maybe the results of
72+
# a recipe) and so on.
73+
74+
# ------------------------------------------------------------------------------
75+
# helper functions
76+
77+
# Template for model results that do not have the sub-model feature
78+
blank_submodels <- function(grid) {
  # Add a `.submodels` list column holding one empty list per grid row.
  # `seq_len()` is used instead of `1:nrow(grid)` so that a zero-row grid
  # produces zero elements rather than iterating over c(1, 0).
  grid %>%
    dplyr::mutate(.submodels = map(seq_len(nrow(grid)), ~ list()))
}
82+
83+
get_fixed_args <- function(info) {
  # Names of the non-sub-model columns to iterate over: the entries of
  # `info$name` whose `info$has_submodel` flag is FALSE.
  # The value is returned directly; ending on an assignment would return
  # it invisibly through an otherwise unused local.
  info$name[!info$has_submodel]
}
87+
88+
get_submodel_info <- function(spec, grid) {
  # Fetch the argument table stored under "<first class of spec>_args" and
  # keep the rows for the engine in use, reduced to `name` (the parsnip
  # argument name) and `has_submodel`.
  args_key <- paste0(class(spec)[1], "_args")
  engine_params <- dplyr::filter(get_from_env(args_key), engine == spec$engine)
  param_info <- dplyr::select(engine_params, name = parsnip, has_submodel)

  # In case a recipe or other activity has grid parameter columns, append
  # them, flagged as having no sub-models.
  grid_names <- names(grid)
  extra_names <- grid_names[!(grid_names %in% param_info$name)]
  if (length(extra_names) > 0) {
    extras <- tibble::tibble(name = extra_names, has_submodel = FALSE)
    param_info <- dplyr::bind_rows(param_info, extras)
  }

  # Only report parameters that actually appear in the grid.
  dplyr::filter(param_info, name %in% grid_names)
}
108+
109+
110+
# ------------------------------------------------------------------------------
111+
# nocov
24112

25113
#' @importFrom utils globalVariables
26114
utils::globalVariables(
27115
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
28-
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type')
116+
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
117+
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
118+
"max_terms", "max_tree", "name", "num_terms", "penalty", "trees",
119+
"sub_neighbors")
29120
)
121+
122+
# nocov end

R/aaa_multi_predict.R

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Define a generic to make multiple predictions for the same model object ------
2+
3+
#' Model predictions across many sub-models
4+
#'
5+
#' For some models, predictions can be made on sub-models in the model object.
6+
#' @param object A `model_fit` object.
7+
#' @param new_data A rectangular data object, such as a data frame.
8+
#' @param type A single character value or `NULL`. Possible values
9+
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
10+
#' or "raw". When `NULL`, `predict()` will choose an appropriate value
11+
#' based on the model's mode.
12+
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
13+
#' such as `type`.
14+
#' @return A tibble with the same number of rows as the data being predicted.
15+
#' Most likely, there is a list-column named `.pred` that is a tibble with
16+
#' multiple rows per sub-model.
17+
#' @export
18+
multi_predict <- function(object, ...) {
  # A fit that failed is stored as a "try-error"; in that case warn and
  # return NULL instead of dispatching to a method.
  if (!inherits(object$fit, "try-error")) {
    UseMethod("multi_predict")
  } else {
    warning("Model fit failed; cannot make predictions.", call. = FALSE)
    NULL
  }
}
25+
26+
#' @export
27+
#' @rdname multi_predict
28+
multi_predict.default <- function(object, ...) {
  # Fallback method: always errors, listing every class of `object` so the
  # caller can see which classes have no `multi_predict` method.
  # `class()` must be given `object`; called with no argument it fails
  # before this message can be built.
  cls <- paste0("'", class(object), "'", collapse = ", ")
  stop(
    "No `multi_predict` method exists for objects with classes ", cls,
    call. = FALSE
  )
}
31+
32+
#' @export
33+
predict.model_spec <- function(object, ...) {
  # Calling predict() on a model specification always errors with a message
  # pointing the user to fit() first.
  msg <- "You must use `fit()` on your model specification before you can use `predict()`."
  stop(msg, call. = FALSE)
}
36+
37+
#' Tools for models that predict on sub-models
38+
#'
39+
#' `has_multi_predict()` tests to see if an object can make multiple
40+
#' predictions on submodels from the same object. `multi_predict_args()`
41+
#' returns the names of the arguments to `multi_predict()` for this model
42+
#' (if any).
43+
#' @param object An object to test.
44+
#' @param ... Not currently used.
45+
#' @return `has_multi_predict()` returns a single logical value while
46+
#' `multi_predict_args()` returns a character vector of argument names (or `NA`
47+
#' if none exist).
48+
#' @keywords internal
49+
#' @examples
50+
#' lm_model_idea <- linear_reg() %>% set_engine("lm")
51+
#' has_multi_predict(lm_model_idea)
52+
#' lm_model_fit <- fit(lm_model_idea, mpg ~ ., data = mtcars)
53+
#' has_multi_predict(lm_model_fit)
54+
#'
55+
#' multi_predict_args(lm_model_fit)
56+
#'
57+
#' library(kknn)
58+
#'
59+
#' knn_fit <-
60+
#' nearest_neighbor(mode = "regression", neighbors = 5) %>%
61+
#' set_engine("kknn") %>%
62+
#' fit(mpg ~ ., mtcars)
63+
#'
64+
#' multi_predict_args(knn_fit)
65+
#'
66+
#' multi_predict(knn_fit, mtcars[1, -1], neighbors = 1:4)$.pred
67+
#' @importFrom utils methods
68+
#' @export
69+
has_multi_predict <- function(object, ...) {
  # S3 generic: the answer is supplied by the method for `object`'s class.
  UseMethod("has_multi_predict")
}
72+
73+
#' @export
74+
#' @rdname has_multi_predict
75+
has_multi_predict.default <- function(object, ...) {
  # Fallback: any object without a more specific method reports FALSE.
  FALSE
}
78+
79+
#' @export
80+
#' @rdname has_multi_predict
81+
has_multi_predict.model_fit <- function(object, ...) {
  # Build one candidate method name per class of `object` and report whether
  # any of them is among the known `multi_predict` methods.
  candidates <- paste0("multi_predict.", class(object))
  known <- utils::methods("multi_predict")
  any(candidates %in% known)
}
86+
87+
#' @export
88+
#' @rdname has_multi_predict
89+
has_multi_predict.workflow <- function(object, ...) {
  # Delegate to the object stored at `fit$model$model` inside the workflow.
  wrapped_model <- object$fit$model$model
  has_multi_predict(wrapped_model)
}
92+
93+
94+
#' @export
#' @rdname has_multi_predict
multi_predict_args <- function(object, ...) {
  # S3 generic: the answer is supplied by the method for `object`'s class.
  UseMethod("multi_predict_args")
}
100+
101+
#' @export
102+
#' @rdname has_multi_predict
103+
multi_predict_args.default <- function(object, ...) {
  # Anything that is not a `model_fit` gets a missing character value.
  if (!inherits(object, "model_fit")) {
    return(NA_character_)
  }
  # Otherwise hand the object to the `model_fit` method directly.
  multi_predict_args.model_fit(object, ...)
}
111+
112+
#' @export
113+
#' @rdname has_multi_predict
114+
multi_predict_args.model_fit <- function(object, ...) {
  # Candidate method names, one per class of `object`, in class order.
  candidates <- paste0("multi_predict.", class(object))
  matched <- candidates[candidates %in% utils::methods("multi_predict")]
  if (length(matched) == 0) {
    return(NA_character_)
  }

  # If more than one class has a method, use the first match (class order),
  # since a name vector of length > 1 cannot be looked up in the namespace.
  method_fn <- utils::getFromNamespace(matched[1], ns = "parsnip")

  # Report the method's formal arguments other than the ones in `omit`.
  omit <- c("object", "new_data", "type", "...")
  arg_names <- names(formals(method_fn))
  arg_names[!(arg_names %in% omit)]
}
128+
129+
#' @export
130+
#' @rdname has_multi_predict
131+
multi_predict_args.workflow <- function(object, ...) {
  # Look up the arguments for the object stored at `fit$model$model` inside
  # the workflow, mirroring `has_multi_predict.workflow()`.
  # The generic must be called explicitly: ending on the extraction alone
  # would return the model object itself (invisibly), not argument names.
  wrapped_model <- object$fit$model$model
  multi_predict_args(wrapped_model, ...)
}

0 commit comments

Comments
 (0)