Skip to content

Commit fc366b6

Browse files
authored
Merge pull request #101 from topepo/set_engine
set_engine api changes
2 parents 5a796fd + eee4b5c commit fc366b6

File tree

116 files changed

+1808
-1921
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+1808
-1921
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
2-
Version: 0.0.0.9004
3-
Title: A Common API to Modeling and analysis Functions
2+
Version: 0.0.0.9005
3+
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. R, spark, stan, etc).
55
Authors@R: c(
66
person("Max", "Kuhn", , "max@rstudio.com", c("aut", "cre")),

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export(predict_raw)
102102
export(predict_raw.model_fit)
103103
export(rand_forest)
104104
export(set_args)
105+
export(set_engine)
105106
export(set_mode)
106107
export(show_call)
107108
export(surv_reg)

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# parsnip 0.0.0.9005
2+
3+
* The engine, and any associated arguments, are not specified using `set_engine`. There is no `engine` argument
4+
5+
16
# parsnip 0.0.0.9004
27

38
* Arguments to modeling functions are now captured as quosures.

R/arguments.R

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ prune_arg_list <- function(x, whitelist = NULL, modified = character(0)) {
5050
x
5151
}
5252

53-
check_others <- function(args, obj, core_args) {
53+
check_eng_args <- function(args, obj, core_args) {
5454
# Make sure that we are not trying to modify an argument that
5555
# is explicitly protected in the method metadata or arg_key
5656
protected_args <- unique(c(obj$protect, core_args))
@@ -95,10 +95,17 @@ set_args <- function(object, ...) {
9595
if (any(main_args == i)) {
9696
object$args[[i]] <- the_dots[[i]]
9797
} else {
98-
object$others[[i]] <- the_dots[[i]]
98+
object$eng_args[[i]] <- the_dots[[i]]
9999
}
100100
}
101-
object
101+
new_model_spec(
102+
cls = class(object)[1],
103+
args = object$args,
104+
eng_args = object$eng_args,
105+
mode = object$mode,
106+
method = NULL,
107+
engine = object$engine
108+
)
102109
}
103110

104111
#' @rdname set_args
@@ -130,6 +137,6 @@ maybe_eval <- function(x) {
130137

131138
eval_args <- function(spec, ...) {
132139
spec$args <- purrr::map(spec$args, maybe_eval)
133-
spec$others <- purrr::map(spec$others, maybe_eval)
140+
spec$eng_args <- purrr::map(spec$eng_args, maybe_eval)
134141
spec
135142
}

R/boost_tree.R

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#' }
2323
#' These arguments are converted to their specific names at the
2424
#' time that the model is fit. Other options and argument can be
25-
#' set using the `...` slot. If left to their defaults
25+
#' set using the `set_engine` function. If left to their defaults
2626
#' here (`NULL`), the values are taken from the underlying model
2727
#' functions. If parameters need to be modified, `update` can be used
2828
#' in lieu of recreating the object from scratch.
@@ -46,11 +46,6 @@
4646
#' @param sample_size An number for the number (or proportion) of data that is
4747
#' exposed to the fitting routine. For `xgboost`, the sampling is done at at
4848
#' each iteration while `C5.0` samples once during traning.
49-
#' @param ... Other arguments to pass to the specific engine's
50-
#' model fit function (see the Engine Details section below). This
51-
#' should not include arguments defined by the main parameters to
52-
#' this function. For the `update` function, the ellipses can
53-
#' contain the primary arguments or any others.
5449
#' @details
5550
#' The data given to the function are not saved and are only used
5651
#' to determine the _mode_ of the model. For `boost_tree`, the
@@ -63,17 +58,12 @@
6358
#' \item \pkg{Spark}: `"spark"`
6459
#' }
6560
#'
66-
#' Main parameter arguments (and those in `...`) can avoid
67-
#' evaluation until the underlying function is executed by wrapping the
68-
#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`).
69-
#'
7061
#'
7162
#' @section Engine Details:
7263
#'
7364
#' Engines may have pre-set default arguments when executing the
74-
#' model fit call. These can be changed by using the `...`
75-
#' argument to pass in the preferred values. For this type of
76-
#' model, the template of the fit calls are:
65+
#' model fit call. For this type of model, the template of the
66+
#' fit calls are:
7767
#'
7868
#' \pkg{xgboost} classification
7969
#'
@@ -109,7 +99,7 @@
10999
#' reloaded and reattached to the `parsnip` object.
110100
#'
111101
#' @importFrom purrr map_lgl
112-
#' @seealso [varying()], [fit()]
102+
#' @seealso [varying()], [fit()], [set_engine()]
113103
#' @examples
114104
#' boost_tree(mode = "classification", trees = 20)
115105
#' # Parameters can be represented by a placeholder:
@@ -121,11 +111,7 @@ boost_tree <-
121111
mtry = NULL, trees = NULL, min_n = NULL,
122112
tree_depth = NULL, learn_rate = NULL,
123113
loss_reduction = NULL,
124-
sample_size = NULL,
125-
...) {
126-
127-
others <- enquos(...)
128-
114+
sample_size = NULL) {
129115
args <- list(
130116
mtry = enquo(mtry),
131117
trees = enquo(trees),
@@ -136,18 +122,14 @@ boost_tree <-
136122
sample_size = enquo(sample_size)
137123
)
138124

139-
if (!(mode %in% boost_tree_modes))
140-
stop("`mode` should be one of: ",
141-
paste0("'", boost_tree_modes, "'", collapse = ", "),
142-
call. = FALSE)
143-
144-
no_value <- !vapply(others, null_value, logical(1))
145-
others <- others[no_value]
146-
147-
out <- list(args = args, others = others,
148-
mode = mode, method = NULL, engine = NULL)
149-
class(out) <- make_classes("boost_tree")
150-
out
125+
new_model_spec(
126+
"boost_tree",
127+
args,
128+
eng_args = NULL,
129+
mode,
130+
method = NULL,
131+
engine = NULL
132+
)
151133
}
152134

153135
#' @export
@@ -167,6 +149,7 @@ print.boost_tree <- function(x, ...) {
167149
#' @export
168150
#' @inheritParams boost_tree
169151
#' @param object A boosted tree model specification.
152+
#' @param ... Not used for `update`.
170153
#' @param fresh A logical for whether the arguments should be
171154
#' modified in-place of or replaced wholesale.
172155
#' @return An updated model specification.
@@ -183,10 +166,8 @@ update.boost_tree <-
183166
mtry = NULL, trees = NULL, min_n = NULL,
184167
tree_depth = NULL, learn_rate = NULL,
185168
loss_reduction = NULL, sample_size = NULL,
186-
fresh = FALSE,
187-
...) {
188-
189-
others <- enquos(...)
169+
fresh = FALSE, ...) {
170+
update_dot_check(...)
190171

191172
args <- list(
192173
mtry = enquo(mtry),
@@ -209,23 +190,27 @@ update.boost_tree <-
209190
object$args[names(args)] <- args
210191
}
211192

212-
if (length(others) > 0) {
213-
if (fresh)
214-
object$others <- others
215-
else
216-
object$others[names(others)] <- others
217-
}
218-
219-
object
193+
new_model_spec(
194+
"boost_tree",
195+
args = object$args,
196+
eng_args = object$eng_args,
197+
mode = object$mode,
198+
method = NULL,
199+
engine = object$engine
200+
)
220201
}
221202

222203
# ------------------------------------------------------------------------------
223204

224205
#' @export
225-
translate.boost_tree <- function(x, engine, ...) {
206+
translate.boost_tree <- function(x, engine = x$engine, ...) {
207+
if (is.null(engine)) {
208+
message("Used `engine = 'xgboost'` for translation.")
209+
engine <- "xgboost"
210+
}
226211
x <- translate.default(x, engine, ...)
227212

228-
if (x$engine == "spark") {
213+
if (engine == "spark") {
229214
if (x$mode == "unknown")
230215
stop(
231216
"For spark boosted trees models, the mode cannot be 'unknown' ",

R/convert_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ convert_form_to_xy_fit <-function(
7676
if (indicators) {
7777
x <- model.matrix(mod_terms, mod_frame, contrasts)
7878
} else {
79-
# this still ignores -vars in formula ¯\_(ツ)_/¯
79+
# this still ignores -vars in formula
8080
x <- model.frame(mod_terms, data)
8181
y_cols <- attr(mod_terms, "response")
8282
if (length(y_cols) > 0)

R/descriptors.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,11 @@ make_descr <- function(object) {
318318
expr_main <- map_lgl(object$args, has_exprs)
319319
else
320320
expr_main <- FALSE
321-
if (length(object$others) > 0)
322-
expr_others <- map_lgl(object$others, has_exprs)
321+
if (length(object$eng_args) > 0)
322+
expr_eng_args <- map_lgl(object$eng_args, has_exprs)
323323
else
324-
expr_others <- FALSE
325-
any(expr_main) | any(expr_others)
324+
expr_eng_args <- FALSE
325+
any(expr_main) | any(expr_eng_args)
326326
}
327327

328328
# Locate descriptors -----------------------------------------------------------
@@ -331,7 +331,7 @@ make_descr <- function(object) {
331331
requires_descrs <- function(object) {
332332
any(c(
333333
map_lgl(object$args, has_any_descrs),
334-
map_lgl(object$others, has_any_descrs)
334+
map_lgl(object$eng_args, has_any_descrs)
335335
))
336336
}
337337

R/engines.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,43 @@ check_installs <- function(x) {
5252
}
5353
}
5454
}
55+
56+
#' Declare a computational engine and specific arguments
57+
#'
58+
#' `set_engine` is used to specify which package or system will be used
59+
#' to fit the model, along with any arguments specific to that software.
60+
#'
61+
#' @param object A model specification.
62+
#' @param engine A character string for the software that should
63+
#' be used to fit the model. This is highly dependent on the type
64+
#' of model (e.g. linear regression, random forest, etc.).
65+
#' @param ... Any optional arguments associated with the chosen computational
66+
#' engine. These are captured as quosures and can be `varying()`.
67+
#' @return An updated model specification.
68+
#' @examples
69+
#' # First, set general arguments using the standardized names
70+
#' mod <-
71+
#' logistic_reg(mixture = 1/3) %>%
72+
#' # now say how you want to fit the model and another other options
73+
#' set_engine("glmnet", nlambda = 10)
74+
#' translate(mod, engine = "glmnet")
75+
#' @export
76+
set_engine <- function(object, engine, ...) {
77+
if (!inherits(object, "model_spec")) {
78+
stop("`object` should have class 'model_spec'.", call. = FALSE)
79+
}
80+
if (!is.character(engine) | length(engine) != 1)
81+
stop("`engine` should be a single character value.", call. = FALSE)
82+
83+
object$engine <- engine
84+
object <- check_engine(object)
85+
86+
new_model_spec(
87+
cls = class(object)[1],
88+
args = object$args,
89+
eng_args = enquos(...),
90+
mode = object$mode,
91+
method = NULL,
92+
engine = object$engine
93+
)
94+
}

0 commit comments

Comments
 (0)