Skip to content

Commit ee4b955

Browse files
committed
function and documentation changes for set_engine
1 parent 5a796fd commit ee4b955

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+817
-1004
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
2-
Version: 0.0.0.9004
3-
Title: A Common API to Modeling and analysis Functions
2+
Version: 0.0.0.9005
3+
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. R, spark, stan, etc).
55
Authors@R: c(
66
person("Max", "Kuhn", , "max@rstudio.com", c("aut", "cre")),

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export(predict_raw)
102102
export(predict_raw.model_fit)
103103
export(rand_forest)
104104
export(set_args)
105+
export(set_engine)
105106
export(set_mode)
106107
export(show_call)
107108
export(surv_reg)

R/boost_tree.R

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#' }
2323
#' These arguments are converted to their specific names at the
2424
#' time that the model is fit. Other options and argument can be
25-
#' set using the `...` slot. If left to their defaults
25+
#' set using the `set_engine` function. If left to their defaults
2626
#' here (`NULL`), the values are taken from the underlying model
2727
#' functions. If parameters need to be modified, `update` can be used
2828
#' in lieu of recreating the object from scratch.
@@ -46,11 +46,6 @@
4646
#' @param sample_size An number for the number (or proportion) of data that is
4747
#' exposed to the fitting routine. For `xgboost`, the sampling is done at at
4848
#' each iteration while `C5.0` samples once during traning.
49-
#' @param ... Other arguments to pass to the specific engine's
50-
#' model fit function (see the Engine Details section below). This
51-
#' should not include arguments defined by the main parameters to
52-
#' this function. For the `update` function, the ellipses can
53-
#' contain the primary arguments or any others.
5449
#' @details
5550
#' The data given to the function are not saved and are only used
5651
#' to determine the _mode_ of the model. For `boost_tree`, the
@@ -63,17 +58,12 @@
6358
#' \item \pkg{Spark}: `"spark"`
6459
#' }
6560
#'
66-
#' Main parameter arguments (and those in `...`) can avoid
67-
#' evaluation until the underlying function is executed by wrapping the
68-
#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`).
69-
#'
7061
#'
7162
#' @section Engine Details:
7263
#'
7364
#' Engines may have pre-set default arguments when executing the
74-
#' model fit call. These can be changed by using the `...`
75-
#' argument to pass in the preferred values. For this type of
76-
#' model, the template of the fit calls are:
65+
#' model fit call. For this type of model, the template of the
66+
#' fit calls are:
7767
#'
7868
#' \pkg{xgboost} classification
7969
#'
@@ -109,7 +99,7 @@
10999
#' reloaded and reattached to the `parsnip` object.
110100
#'
111101
#' @importFrom purrr map_lgl
112-
#' @seealso [varying()], [fit()]
102+
#' @seealso [varying()], [fit()], [set_engine()]
113103
#' @examples
114104
#' boost_tree(mode = "classification", trees = 20)
115105
#' # Parameters can be represented by a placeholder:
@@ -121,11 +111,7 @@ boost_tree <-
121111
mtry = NULL, trees = NULL, min_n = NULL,
122112
tree_depth = NULL, learn_rate = NULL,
123113
loss_reduction = NULL,
124-
sample_size = NULL,
125-
...) {
126-
127-
others <- enquos(...)
128-
114+
sample_size = NULL) {
129115
args <- list(
130116
mtry = enquo(mtry),
131117
trees = enquo(trees),
@@ -141,10 +127,7 @@ boost_tree <-
141127
paste0("'", boost_tree_modes, "'", collapse = ", "),
142128
call. = FALSE)
143129

144-
no_value <- !vapply(others, null_value, logical(1))
145-
others <- others[no_value]
146-
147-
out <- list(args = args, others = others,
130+
out <- list(args = args, others = NULL,
148131
mode = mode, method = NULL, engine = NULL)
149132
class(out) <- make_classes("boost_tree")
150133
out
@@ -183,11 +166,7 @@ update.boost_tree <-
183166
mtry = NULL, trees = NULL, min_n = NULL,
184167
tree_depth = NULL, learn_rate = NULL,
185168
loss_reduction = NULL, sample_size = NULL,
186-
fresh = FALSE,
187-
...) {
188-
189-
others <- enquos(...)
190-
169+
fresh = FALSE) {
191170
args <- list(
192171
mtry = enquo(mtry),
193172
trees = enquo(trees),
@@ -209,23 +188,20 @@ update.boost_tree <-
209188
object$args[names(args)] <- args
210189
}
211190

212-
if (length(others) > 0) {
213-
if (fresh)
214-
object$others <- others
215-
else
216-
object$others[names(others)] <- others
217-
}
218-
219191
object
220192
}
221193

222194
# ------------------------------------------------------------------------------
223195

224196
#' @export
225-
translate.boost_tree <- function(x, engine, ...) {
197+
translate.boost_tree <- function(x, engine = x$engine, ...) {
198+
if (is.null(engine)) {
199+
message("Used `engine = 'xgboost'` for translation.")
200+
engine <- "xgboost"
201+
}
226202
x <- translate.default(x, engine, ...)
227203

228-
if (x$engine == "spark") {
204+
if (engine == "spark") {
229205
if (x$mode == "unknown")
230206
stop(
231207
"For spark boosted trees models, the mode cannot be 'unknown' ",

R/fit.R

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,11 @@
1717
#' below). A data frame containing all relevant variables (e.g.
1818
#' outcome(s), predictors, case weights, etc). Note: when needed, a
1919
#' \emph{named argument} should be used.
20-
#' @param engine A character string for the software that should
21-
#' be used to fit the model. This is highly dependent on the type
22-
#' of model (e.g. linear regression, random forest, etc.).
2320
#' @param control A named list with elements `verbosity` and
2421
#' `catch`. See [fit_control()].
2522
#' @param ... Not currently used; values passed here will be
2623
#' ignored. Other options required to fit the model should be
27-
#' passed using the `others` argument in the original model
28-
#' specification.
24+
#' passed using `set_engine`.
2925
#' @details `fit` and `fit_xy` substitute the current arguments in the model
3026
#' specification into the computational engine's code, checks them
3127
#' for validity, then fits the model using the data and the
@@ -92,11 +88,13 @@ fit.model_spec <-
9288
function(object,
9389
formula = NULL,
9490
data = NULL,
95-
engine = object$engine,
9691
control = fit_control(),
9792
...
9893
) {
9994
dots <- quos(...)
95+
if (any(names(dots) == "engine"))
96+
stop("Use `set_engine` to supply the engine.", call. = FALSE)
97+
10098
if (all(c("x", "y") %in% names(dots)))
10199
stop("`fit.model_spec` is for the formula methods. Use `fit_xy` instead.",
102100
call. = FALSE)
@@ -109,10 +107,8 @@ fit.model_spec <-
109107
eval_env$formula <- formula
110108
fit_interface <-
111109
check_interface(eval_env$formula, eval_env$data, cl, object)
112-
object$engine <- engine
113-
object <- check_engine(object)
114110

115-
if (engine == "spark" && !inherits(eval_env$data, "tbl_spark"))
111+
if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark"))
116112
stop(
117113
"spark objects can only be used with the formula interface to `fit` ",
118114
"with a spark data object.", call. = FALSE
@@ -122,7 +118,7 @@ fit.model_spec <-
122118
object <- get_method(object, engine = object$engine)
123119

124120
check_installs(object) # TODO rewrite with pkgman
125-
# TODO Should probably just load the namespace
121+
126122
load_libs(object, control$verbosity < 2)
127123

128124
interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")
@@ -178,20 +174,20 @@ fit_xy.model_spec <-
178174
function(object,
179175
x = NULL,
180176
y = NULL,
181-
engine = object$engine,
182177
control = fit_control(),
183178
...
184179
) {
180+
dots <- quos(...)
181+
if (any(names(dots) == "engine"))
182+
stop("Use `set_engine` to supply the engine.", call. = FALSE)
185183

186184
cl <- match.call(expand.dots = TRUE)
187185
eval_env <- rlang::env()
188186
eval_env$x <- x
189187
eval_env$y <- y
190188
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
191-
object$engine <- engine
192-
object <- check_engine(object)
193189

194-
if (engine == "spark")
190+
if (object$engine == "spark")
195191
stop(
196192
"spark objects can only be used with the formula interface to `fit` ",
197193
"with a spark data object.", call. = FALSE
@@ -201,7 +197,7 @@ fit_xy.model_spec <-
201197
object <- get_method(object, engine = object$engine)
202198

203199
check_installs(object) # TODO rewrite with pkgman
204-
# TODO Should probably just load the namespace
200+
205201
load_libs(object, control$verbosity < 2)
206202

207203
interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")

R/linear_reg.R

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#' }
1313
#' These arguments are converted to their specific names at the
1414
#' time that the model is fit. Other options and argument can be
15-
#' set using the `...` slot. If left to their defaults
15+
#' set using `set_engine`. If left to their defaults
1616
#' here (`NULL`), the values are taken from the underlying model
1717
#' functions. If parameters need to be modified, `update` can be used
1818
#' in lieu of recreating the object from scratch.
@@ -25,7 +25,6 @@
2525
#' represents the proportion of regularization that is used for the
2626
#' L2 penalty (i.e. weight decay, or ridge regression) versus L1
2727
#' (the lasso) (`glmnet` and `spark` only).
28-
#'
2928
#' @details
3029
#' The data given to the function are not saved and are only used
3130
#' to determine the _mode_ of the model. For `linear_reg`, the
@@ -42,8 +41,7 @@
4241
#' @section Engine Details:
4342
#'
4443
#' Engines may have pre-set default arguments when executing the
45-
#' model fit call. These can be changed by using the `...`
46-
#' argument to pass in the preferred values. For this type of
44+
#' model fit call. For this type of
4745
#' model, the template of the fit calls are:
4846
#'
4947
#' \pkg{lm}
@@ -92,7 +90,7 @@
9290
#' separately saved to disk. In a new session, the object can be
9391
#' reloaded and reattached to the `parsnip` object.
9492
#'
95-
#' @seealso [varying()], [fit()]
93+
#' @seealso [varying()], [fit()], [set_engine()]
9694
#' @examples
9795
#' linear_reg()
9896
#' # Parameters can be represented by a placeholder:
@@ -102,10 +100,7 @@
102100
linear_reg <-
103101
function(mode = "regression",
104102
penalty = NULL,
105-
mixture = NULL,
106-
...) {
107-
108-
others <- enquos(...)
103+
mixture = NULL) {
109104

110105
args <- list(
111106
penalty = enquo(penalty),
@@ -119,13 +114,10 @@ linear_reg <-
119114
call. = FALSE
120115
)
121116

122-
no_value <- !vapply(others, is.null, logical(1))
123-
others <- others[no_value]
124-
125117
# write a constructor function
126118
out <- list(
127119
args = args,
128-
others = others,
120+
others = NULL,
129121
mode = mode,
130122
method = NULL,
131123
engine = NULL
@@ -162,10 +154,7 @@ print.linear_reg <- function(x, ...) {
162154
update.linear_reg <-
163155
function(object,
164156
penalty = NULL, mixture = NULL,
165-
fresh = FALSE,
166-
...) {
167-
168-
others <- enquos(...)
157+
fresh = FALSE) {
169158

170159
args <- list(
171160
penalty = enquo(penalty),
@@ -182,13 +171,6 @@ update.linear_reg <-
182171
object$args[names(args)] <- args
183172
}
184173

185-
if (length(others) > 0) {
186-
if (fresh)
187-
object$others <- others
188-
else
189-
object$others[names(others)] <- others
190-
}
191-
192174
object
193175
}
194176

R/logistic_reg.R

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#' }
1313
#' These arguments are converted to their specific names at the
1414
#' time that the model is fit. Other options and argument can be
15-
#' set using the `...` slot. If left to their defaults
15+
#' set using `set_engine`. If left to their defaults
1616
#' here (`NULL`), the values are taken from the underlying model
1717
#' functions. If parameters need to be modified, `update` can be used
1818
#' in lieu of recreating the object from scratch.
@@ -39,8 +39,7 @@
3939
#' @section Engine Details:
4040
#'
4141
#' Engines may have pre-set default arguments when executing the
42-
#' model fit call. These can be changed by using the `...`
43-
#' argument to pass in the preferred values. For this type of
42+
#' model fit call. For this type of
4443
#' model, the template of the fit calls are:
4544
#'
4645
#' \pkg{glm}
@@ -100,10 +99,7 @@
10099
logistic_reg <-
101100
function(mode = "classification",
102101
penalty = NULL,
103-
mixture = NULL,
104-
...) {
105-
106-
others <- enquos(...)
102+
mixture = NULL) {
107103

108104
args <- list(
109105
penalty = enquo(penalty),
@@ -117,13 +113,10 @@ logistic_reg <-
117113
call. = FALSE
118114
)
119115

120-
no_value <- !vapply(others, is.null, logical(1))
121-
others <- others[no_value]
122-
123116
# write a constructor function
124117
out <- list(
125118
args = args,
126-
others = others,
119+
others = NULL,
127120
mode = mode,
128121
method = NULL,
129122
engine = NULL
@@ -160,10 +153,7 @@ print.logistic_reg <- function(x, ...) {
160153
update.logistic_reg <-
161154
function(object,
162155
penalty = NULL, mixture = NULL,
163-
fresh = FALSE,
164-
...) {
165-
166-
others <- enquos(...)
156+
fresh = FALSE) {
167157

168158
args <- list(
169159
penalty = enquo(penalty),
@@ -180,13 +170,6 @@ update.logistic_reg <-
180170
object$args[names(args)] <- args
181171
}
182172

183-
if (length(others) > 0) {
184-
if (fresh)
185-
object$others <- others
186-
else
187-
object$others[names(others)] <- others
188-
}
189-
190173
object
191174
}
192175

0 commit comments

Comments
 (0)