Skip to content

Commit 199385b

Browse files
committed
default engine changes for #513
1 parent 7d78009 commit 199385b

37 files changed

+177
-81
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# parsnip (development version)
22

3+
* Each model now has a default engine that is used when the model is defined. The default for each model is listed in the help documents. This also adds functionality to change to declare an engine in the model specification function. `set_engine()` is still required if engine-specific arguments need to be added. (#513)
4+
5+
* The default engine for `multinom_reg()` was changed to `nnet`.
6+
37
* The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between formula and matrix interface are now exported for developer use (#508).
48

59
* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).

R/boost_tree.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
#' @param mode A single character string for the type of model.
3333
#' Possible values for this model are "unknown", "regression", or
3434
#' "classification".
35+
#' @param engine A character string for the method of fitting. Possible engines
36+
#' are listed above. The default for this model is `"xgboost"`.
3537
#' @param mtry A number for the number (or proportion) of predictors that will
3638
#' be randomly sampled at each split when creating the tree models (`xgboost`
3739
#' only).
@@ -92,6 +94,7 @@
9294

9395
boost_tree <-
9496
function(mode = "unknown",
97+
engine = "xgboost",
9598
mtry = NULL, trees = NULL, min_n = NULL,
9699
tree_depth = NULL, learn_rate = NULL,
97100
loss_reduction = NULL,
@@ -114,7 +117,7 @@ boost_tree <-
114117
eng_args = NULL,
115118
mode,
116119
method = NULL,
117-
engine = NULL
120+
engine = engine
118121
)
119122
}
120123

R/decision_tree.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#' @param mode A single character string for the type of model.
2525
#' Possible values for this model are "unknown", "regression", or
2626
#' "classification".
27+
#' @param engine A character string for the method of fitting. Possible engines
28+
#' are listed above. The default for this model is `"rpart"`.
2729
#' @param cost_complexity A positive number for the the cost/complexity
2830
#' parameter (a.k.a. `Cp`) used by CART models (`rpart` only).
2931
#' @param tree_depth An integer for maximum depth of the tree.
@@ -69,7 +71,8 @@
6971
#' @export
7072

7173
decision_tree <-
72-
function(mode = "unknown", cost_complexity = NULL, tree_depth = NULL, min_n = NULL) {
74+
function(mode = "unknown", engine = "rpart", cost_complexity = NULL,
75+
tree_depth = NULL, min_n = NULL) {
7376

7477
args <- list(
7578
cost_complexity = enquo(cost_complexity),
@@ -83,7 +86,7 @@ decision_tree <-
8386
eng_args = NULL,
8487
mode = mode,
8588
method = NULL,
86-
engine = NULL
89+
engine = engine
8790
)
8891
}
8992

R/linear_reg.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#' in lieu of recreating the object from scratch.
1919
#' @param mode A single character string for the type of model.
2020
#' The only possible value for this model is "regression".
21+
#' @param engine A character string for the method of fitting. Possible engines
22+
#' are listed above. The default for this model is `"lm"`.
2123
#' @param penalty A non-negative number representing the total
2224
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
2325
#' For `keras` models, this corresponds to purely L2 regularization
@@ -70,6 +72,7 @@
7072
#' @importFrom purrr map_lgl
7173
linear_reg <-
7274
function(mode = "regression",
75+
engine = "lm",
7376
penalty = NULL,
7477
mixture = NULL) {
7578

@@ -84,7 +87,7 @@ linear_reg <-
8487
eng_args = NULL,
8588
mode = mode,
8689
method = NULL,
87-
engine = NULL
90+
engine = engine
8891
)
8992
}
9093

R/logistic_reg.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#' in lieu of recreating the object from scratch.
1919
#' @param mode A single character string for the type of model.
2020
#' The only possible value for this model is "classification".
21+
#' @param engine A character string for the method of fitting. Possible engines
22+
#' are listed above. The default for this model is `"glm"`.
2123
#' @param penalty A non-negative number representing the total
2224
#' amount of regularization (`glmnet`, `LiblineaR`, `keras`, and `spark` only).
2325
#' For `keras` models, this corresponds to purely L2 regularization
@@ -69,6 +71,7 @@
6971
#' @importFrom purrr map_lgl
7072
logistic_reg <-
7173
function(mode = "classification",
74+
engine = "glm",
7275
penalty = NULL,
7376
mixture = NULL) {
7477

@@ -83,7 +86,7 @@ logistic_reg <-
8386
eng_args = NULL,
8487
mode = mode,
8588
method = NULL,
86-
engine = NULL
89+
engine = engine
8790
)
8891
}
8992

R/mars.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#' @param mode A single character string for the type of model.
2626
#' Possible values for this model are "unknown", "regression", or
2727
#' "classification".
28+
#' @param engine A character string for the method of fitting. Possible engines
29+
#' are listed above. The default for this model is `"earth"`.
2830
#' @param num_terms The number of features that will be retained in the
2931
#' final model, including the intercept.
3032
#' @param prod_degree The highest possible interaction degree.
@@ -45,7 +47,7 @@
4547
#' mars(mode = "regression", num_terms = 5)
4648
#' @export
4749
mars <-
48-
function(mode = "unknown",
50+
function(mode = "unknown", engine = "earth",
4951
num_terms = NULL, prod_degree = NULL, prune_method = NULL) {
5052

5153
args <- list(
@@ -60,7 +62,7 @@ mars <-
6062
eng_args = NULL,
6163
mode = mode,
6264
method = NULL,
63-
engine = NULL
65+
engine = engine
6466
)
6567
}
6668

R/mlp.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
#' @param mode A single character string for the type of model.
3131
#' Possible values for this model are "unknown", "regression", or
3232
#' "classification".
33+
#' @param engine A character string for the method of fitting. Possible engines
34+
#' are listed above. The default for this model is `"nnet"`.
3335
#' @param hidden_units An integer for the number of units in the hidden model.
3436
#' @param penalty A non-negative numeric value for the amount of weight
3537
#' decay.
@@ -63,7 +65,7 @@
6365
#' @export
6466

6567
mlp <-
66-
function(mode = "unknown",
68+
function(mode = "unknown", engine = "nnet",
6769
hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL,
6870
activation = NULL) {
6971

@@ -81,7 +83,7 @@ mlp <-
8183
eng_args = NULL,
8284
mode = mode,
8385
method = NULL,
84-
engine = NULL
86+
engine = engine
8587
)
8688
}
8789

R/multinom_reg.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#' in lieu of recreating the object from scratch.
1919
#' @param mode A single character string for the type of model.
2020
#' The only possible value for this model is "classification".
21+
#' @param engine A character string for the method of fitting. Possible engines
22+
#' are listed above. The default for this model is `"nnet"`.
2123
#' @param penalty A non-negative number representing the total
2224
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
2325
#' For `keras` models, this corresponds to purely L2 regularization
@@ -33,7 +35,7 @@
3335
#' The model can be created using the `fit()` function using the
3436
#' following _engines_:
3537
#' \itemize{
36-
#' \item \pkg{R}: `"glmnet"` (the default), `"nnet"`
38+
#' \item \pkg{R}: `"glmnet"`, `"nnet"` (the default)
3739
#' \item \pkg{Spark}: `"spark"`
3840
#' \item \pkg{keras}: `"keras"`
3941
#' }
@@ -64,6 +66,7 @@
6466
#' @importFrom purrr map_lgl
6567
multinom_reg <-
6668
function(mode = "classification",
69+
engine = "nnet",
6770
penalty = NULL,
6871
mixture = NULL) {
6972

@@ -78,7 +81,7 @@ multinom_reg <-
7881
eng_args = NULL,
7982
mode = mode,
8083
method = NULL,
81-
engine = NULL
84+
engine = engine
8285
)
8386
}
8487

R/nearest_neighbor.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
#' @param mode A single character string for the type of model.
2727
#' Possible values for this model are `"unknown"`, `"regression"`, or
2828
#' `"classification"`.
29-
#'
29+
#' @param engine A character string for the method of fitting. Possible engines
30+
#' are listed above. The default for this model is `"kknn"`.
3031
#' @param neighbors A single integer for the number of neighbors
3132
#' to consider (often called `k`). For \pkg{kknn}, a value of 5
3233
#' is used if `neighbors` is not specified.
@@ -57,6 +58,7 @@
5758
#'
5859
#' @export
5960
nearest_neighbor <- function(mode = "unknown",
61+
engine = "kknn",
6062
neighbors = NULL,
6163
weight_func = NULL,
6264
dist_power = NULL) {
@@ -72,7 +74,7 @@ nearest_neighbor <- function(mode = "unknown",
7274
eng_args = NULL,
7375
mode = mode,
7476
method = NULL,
75-
engine = NULL
77+
engine = engine
7678
)
7779
}
7880

R/proportional_hazards.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#'
1919
#' @param mode A single character string for the type of model.
2020
#' Possible values for this model are "unknown", or "censored regression".
21+
#' @param engine A character string for the method of fitting. Possible engines
22+
#' are listed above. The default for this model is `"survival"`.
2123
#' @inheritParams linear_reg
2224
#'
2325
#' @details
@@ -29,9 +31,11 @@
2931
#' show_engines("proportional_hazards")
3032
#' @keywords internal
3133
#' @export
32-
proportional_hazards <- function(mode = "censored regression",
33-
penalty = NULL,
34-
mixture = NULL) {
34+
proportional_hazards <- function(
35+
mode = "censored regression",
36+
engine = "survival",
37+
penalty = NULL,
38+
mixture = NULL) {
3539

3640
args <- list(
3741
penalty = enquo(penalty),
@@ -44,7 +48,7 @@ proportional_hazards <- function(mode = "censored regression",
4448
eng_args = NULL,
4549
mode = mode,
4650
method = NULL,
47-
engine = NULL
51+
engine = engine
4852
)
4953
}
5054

0 commit comments

Comments
 (0)