
Commit 3b711f6: synced

1 parent 543094e

134 files changed (+7429, -4009 lines)

.travis.yml

Lines changed: 11 additions & 0 deletions
@@ -8,11 +8,22 @@ sudo: true
 warnings_are_errors: false
 
 r:
+  - 3.1
+  - 3.2
+  - oldrel
   - release
   - devel
 
 env:
+  global:
     - KERAS_BACKEND="tensorflow"
+    - MAKEFLAGS="-j 2"
+
+# until we troubleshoot these issues
+matrix:
+  allow_failures:
+    - r: 3.1
+    - r: 3.2
 
 r_binary_packages:
   - rstan

DESCRIPTION

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 Package: parsnip
-Version: 0.0.0.9003
+Version: 0.0.0.9004
 Title: A Common API to Modeling and analysis Functions
 Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. R, spark, stan, etc).
 Authors@R: c(
@@ -25,7 +25,8 @@ Imports:
     glue,
    magrittr,
     stats,
-    tidyr
+    tidyr,
+    globals
 Roxygen: list(markdown = TRUE)
 RoxygenNote: 6.1.0.9000
 Suggests:

NAMESPACE

Lines changed: 27 additions & 0 deletions
@@ -9,13 +9,23 @@ S3method(multi_predict,"_lognet")
 S3method(multi_predict,"_multnet")
 S3method(multi_predict,"_xgb.Booster")
 S3method(multi_predict,default)
+S3method(predict,"_elnet")
+S3method(predict,"_lognet")
 S3method(predict,"_multnet")
 S3method(predict,model_fit)
+S3method(predict_class,"_lognet")
 S3method(predict_class,model_fit)
+S3method(predict_classprob,"_lognet")
+S3method(predict_classprob,"_multnet")
 S3method(predict_classprob,model_fit)
 S3method(predict_confint,model_fit)
+S3method(predict_num,"_elnet")
 S3method(predict_num,model_fit)
 S3method(predict_predint,model_fit)
+S3method(predict_quantile,model_fit)
+S3method(predict_raw,"_elnet")
+S3method(predict_raw,"_lognet")
+S3method(predict_raw,"_multnet")
 S3method(predict_raw,model_fit)
 S3method(print,boost_tree)
 S3method(print,linear_reg)
@@ -49,13 +59,23 @@ S3method(varying_args,model_spec)
 S3method(varying_args,recipe)
 S3method(varying_args,step)
 export("%>%")
+export(.cols)
+export(.dat)
+export(.facts)
+export(.lvls)
+export(.obs)
+export(.preds)
+export(.x)
+export(.y)
+export(C5.0_train)
 export(boost_tree)
 export(check_empty_ellipse)
 export(fit)
 export(fit.model_spec)
 export(fit_control)
 export(fit_xy)
 export(fit_xy.model_spec)
+export(keras_mlp)
 export(linear_reg)
 export(logistic_reg)
 export(make_classes)
@@ -76,6 +96,8 @@ export(predict_num)
 export(predict_num.model_fit)
 export(predict_predint)
 export(predict_predint.model_fit)
+export(predict_quantile)
+export(predict_quantile.model_fit)
 export(predict_raw)
 export(predict_raw.model_fit)
 export(rand_forest)
@@ -89,14 +111,17 @@ export(varying_args)
 export(varying_args.model_spec)
 export(varying_args.recipe)
 export(varying_args.step)
+export(xgb_train)
 import(rlang)
 importFrom(dplyr,arrange)
 importFrom(dplyr,as_tibble)
 importFrom(dplyr,bind_cols)
+importFrom(dplyr,bind_rows)
 importFrom(dplyr,collect)
 importFrom(dplyr,full_join)
 importFrom(dplyr,funs)
 importFrom(dplyr,group_by)
+importFrom(dplyr,mutate)
 importFrom(dplyr,pull)
 importFrom(dplyr,rename)
 importFrom(dplyr,rename_at)
@@ -120,6 +145,7 @@ importFrom(purrr,map_dbl)
 importFrom(purrr,map_df)
 importFrom(purrr,map_dfr)
 importFrom(purrr,map_lgl)
+importFrom(rlang,eval_tidy)
 importFrom(rlang,sym)
 importFrom(rlang,syms)
 importFrom(stats,.checkMFClasses)
@@ -138,6 +164,7 @@ importFrom(stats,predict)
 importFrom(stats,qnorm)
 importFrom(stats,qt)
 importFrom(stats,quantile)
+importFrom(stats,setNames)
 importFrom(stats,terms)
 importFrom(stats,update)
 importFrom(tibble,as_tibble)

NEWS.md

Lines changed: 7 additions & 1 deletion
@@ -1,7 +1,13 @@
+# parsnip 0.0.0.9004
+
+* Arguments to modeling functions are now captured as quosures.
+* `others` has been replaced by `...`
+* Data descriptor names have been changed and are now functions. The descriptor definitions for "cols" and "preds" have been switched.
+
 # parsnip 0.0.0.9003
 
 * `regularization` was changed to `penalty` in a few models to be consistent with [this change](tidymodels/model-implementation-principles@08d3afd).
-* if a mode is not chosen in the model specification, it is assigned at the time of fit. [51](https://github.com/topepo/parsnip/issues/51)
+* If a mode is not chosen in the model specification, it is assigned at the time of fit. [51](https://github.com/topepo/parsnip/issues/51)
 * The underlying modeling packages now are loaded by namespace. There will be some exceptions noted in the documentation for each model. For example, in some `predict` methods, the `earth` package will need to be attached to be fully operational.
 
 # parsnip 0.0.0.9002
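
Since arguments are now captured as quosures, expressions such as data descriptor calls can be supplied without being evaluated at specification time. A minimal sketch of the capture pattern, with a toy `spec_sketch` function standing in for the real constructors (which use `enquo()`/`enquos()` as shown in the `R/boost_tree.R` diff below):

library(rlang)

# Toy stand-in for the new constructors: quote the main argument and
# the dots instead of evaluating them right away.
spec_sketch <- function(mtry = NULL, ...) {
  list(args = list(mtry = enquo(mtry)), others = enquos(...))
}

sp <- spec_sketch(mtry = floor(sqrt(.preds())), nthread = 2)
quo_get_expr(sp$args$mtry)
#> floor(sqrt(.preds()))   (still unevaluated at specification time)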

R/aaa_spark_helpers.R

Lines changed: 4 additions & 6 deletions
@@ -3,12 +3,10 @@
 #' @importFrom dplyr starts_with rename rename_at vars funs
 format_spark_probs <- function(results, object) {
   results <- dplyr::select(results, starts_with("probability_"))
-  results <- dplyr::rename_at(
-    results,
-    vars(starts_with("probability_")),
-    funs(gsub("probability", "pred", .))
-  )
-  results
+  p <- ncol(results)
+  lvl <- paste0("probability_", 0:(p - 1))
+  names(lvl) <- paste0("pred_", object$fit$.index_labels)
+  results %>% rename(!!!syms(lvl))
 }
 
 format_spark_class <- function(results, object) {
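
The rewritten helper renames columns by splicing a named vector into `dplyr::rename()`. A small sketch of that idiom on an ordinary tibble; the `pred_` labels are mocked here, whereas the real function takes them from `object$fit$.index_labels`:

library(dplyr)
library(rlang)

probs <- tibble::tibble(probability_0 = c(0.9, 0.2),
                        probability_1 = c(0.1, 0.8))

# named vector: names are the new column names, values the old ones
lvl <- paste0("probability_", 0:1)
names(lvl) <- paste0("pred_", c("no", "yes"))  # mocked class labels
probs %>% rename(!!!syms(lvl))
#> a tibble with columns pred_no and pred_yes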

R/arguments.R

Lines changed: 17 additions & 1 deletion
@@ -86,7 +86,7 @@ check_others <- function(args, obj, core_args) {
 #'
 #' @export
 set_args <- function(object, ...) {
-  the_dots <- list(...)
+  the_dots <- enquos(...)
   if (length(the_dots) == 0)
     stop("Please pass at least one named argument.", call. = FALSE)
   main_args <- names(object$args)
@@ -116,4 +116,20 @@ set_mode <- function(object, mode) {
   object
 }
 
+# ------------------------------------------------------------------------------
 
+#' @importFrom rlang eval_tidy
+#' @importFrom purrr map
+maybe_eval <- function(x) {
+  # if descriptors are in `x`, eval fails
+  y <- try(rlang::eval_tidy(x), silent = TRUE)
+  if (inherits(y, "try-error"))
+    y <- x
+  y
+}
+
+eval_args <- function(spec, ...) {
+  spec$args <- purrr::map(spec$args, maybe_eval)
+  spec$others <- purrr::map(spec$others, maybe_eval)
+  spec
+}
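
`maybe_eval` is what makes the quosure-captured arguments usable downstream: literal values collapse to constants, while expressions that reference not-yet-available data descriptors survive as quosures. An illustration using the definition above (the descriptor call only succeeds later, inside `fit()`, so here it is returned unevaluated):

library(rlang)

maybe_eval <- function(x) {
  # if descriptors are in `x`, eval fails
  y <- try(rlang::eval_tidy(x), silent = TRUE)
  if (inherits(y, "try-error"))
    y <- x
  y
}

maybe_eval(quo(5 + 1))                  # evaluates cleanly: 6
maybe_eval(quo(floor(sqrt(.preds()))))  # eval fails, quosure returned intact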

R/boost_tree.R

Lines changed: 96 additions & 32 deletions
@@ -22,16 +22,14 @@
 #' }
 #' These arguments are converted to their specific names at the
 #' time that the model is fit. Other options and argument can be
-#' set using the `others` argument. If left to their defaults
+#' set using the `...` slot. If left to their defaults
 #' here (`NULL`), the values are taken from the underlying model
 #' functions. If parameters need to be modified, `update` can be used
 #' in lieu of recreating the object from scratch.
 #'
 #' @param mode A single character string for the type of model.
 #' Possible values for this model are "unknown", "regression", or
 #' "classification".
-#' @param others A named list of arguments to be used by the
-#' underlying models (e.g., `xgboost::xgb.train`, etc.). .
 #' @param mtry An number for the number (or proportion) of predictors that will
 #' be randomly sampled at each split when creating the tree models (`xgboost`
 #' only).
@@ -48,8 +46,11 @@
 #' @param sample_size An number for the number (or proportion) of data that is
 #' exposed to the fitting routine. For `xgboost`, the sampling is done at at
 #' each iteration while `C5.0` samples once during traning.
-#' @param ... Used for method consistency. Any arguments passed to
-#' the ellipses will result in an error. Use `others` instead.
+#' @param ... Other arguments to pass to the specific engine's
+#' model fit function (see the Engine Details section below). This
+#' should not include arguments defined by the main parameters to
+#' this function. For the `update` function, the ellipses can
+#' contain the primary arguments or any others.
 #' @details
 #' The data given to the function are not saved and are only used
 #' to determine the _mode_ of the model. For `boost_tree`, the
@@ -62,12 +63,15 @@
 #' \item \pkg{Spark}: `"spark"`
 #' }
 #'
-#' Main parameter arguments (and those in `others`) can avoid
+#' Main parameter arguments (and those in `...`) can avoid
 #' evaluation until the underlying function is executed by wrapping the
 #' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`).
 #'
+#'
+#' @section Engine Details:
+#'
 #' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `others`
+#' model fit call. These can be changed by using the `...`
 #' argument to pass in the preferred values. For this type of
 #' model, the template of the fit calls are:
 #'
@@ -114,35 +118,30 @@
 
 boost_tree <-
   function(mode = "unknown",
-           ...,
            mtry = NULL, trees = NULL, min_n = NULL,
           tree_depth = NULL, learn_rate = NULL,
           loss_reduction = NULL,
           sample_size = NULL,
-           others = list()) {
-    check_empty_ellipse(...)
+           ...) {
+
+    others <- enquos(...)
+
+    args <- list(
+      mtry = enquo(mtry),
+      trees = enquo(trees),
+      min_n = enquo(min_n),
+      tree_depth = enquo(tree_depth),
+      learn_rate = enquo(learn_rate),
+      loss_reduction = enquo(loss_reduction),
+      sample_size = enquo(sample_size)
+    )
 
     if (!(mode %in% boost_tree_modes))
       stop("`mode` should be one of: ",
            paste0("'", boost_tree_modes, "'", collapse = ", "),
            call. = FALSE)
 
-    if (is.numeric(trees) && trees < 0)
-      stop("`trees` should be >= 1", call. = FALSE)
-    if (is.numeric(sample_size) && (sample_size < 0 | sample_size > 1))
-      stop("`sample_size` should be within [0,1]", call. = FALSE)
-    if (is.numeric(tree_depth) && tree_depth < 0)
-      stop("`tree_depth` should be >= 1", call. = FALSE)
-    if (is.numeric(min_n) && min_n < 0)
-      stop("`min_n` should be >= 1", call. = FALSE)
-
-    args <- list(
-      mtry = mtry, trees = trees, min_n = min_n, tree_depth = tree_depth,
-      learn_rate = learn_rate, loss_reduction = loss_reduction,
-      sample_size = sample_size
-    )
-
-    no_value <- !vapply(others, is.null, logical(1))
+    no_value <- !vapply(others, null_value, logical(1))
     others <- others[no_value]
 
     out <- list(args = args, others = others,
@@ -184,16 +183,20 @@ update.boost_tree <-
            mtry = NULL, trees = NULL, min_n = NULL,
            tree_depth = NULL, learn_rate = NULL,
            loss_reduction = NULL, sample_size = NULL,
-           others = list(),
            fresh = FALSE,
            ...) {
-    check_empty_ellipse(...)
+
+    others <- enquos(...)
 
     args <- list(
-      mtry = mtry, trees = trees, min_n = min_n, tree_depth = tree_depth,
-      learn_rate = learn_rate, loss_reduction = loss_reduction,
-      sample_size = sample_size
-    )
+      mtry = enquo(mtry),
+      trees = enquo(trees),
+      min_n = enquo(min_n),
+      tree_depth = enquo(tree_depth),
+      learn_rate = enquo(learn_rate),
+      loss_reduction = enquo(loss_reduction),
+      sample_size = enquo(sample_size)
+    )
 
     # TODO make these blocks into a function and document well
     if (fresh) {
@@ -235,9 +238,45 @@ translate.boost_tree <- function(x, engine, ...) {
   x
 }
 
+# ------------------------------------------------------------------------------
+
+check_args.boost_tree <- function(object) {
+
+  args <- lapply(object$args, rlang::eval_tidy)
+
+  if (is.numeric(args$trees) && args$trees < 0)
+    stop("`trees` should be >= 1", call. = FALSE)
+  if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1))
+    stop("`sample_size` should be within [0,1]", call. = FALSE)
+  if (is.numeric(args$tree_depth) && args$tree_depth < 0)
+    stop("`tree_depth` should be >= 1", call. = FALSE)
+  if (is.numeric(args$min_n) && args$min_n < 0)
+    stop("`min_n` should be >= 1", call. = FALSE)
+
+  invisible(object)
+}
 
 # xgboost helpers --------------------------------------------------------------
 
+#' Boosted trees via xgboost
+#'
+#' `xgb_train` is a wrapper for `xgboost` tree-based models
+#' where all of the model arguments are in the main function.
+#'
+#' @param x A data frame or matrix of predictors
+#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
+#' @param max_depth An integer for the maximum depth of the tree.
+#' @param nrounds An integer for the number of boosting iterations.
+#' @param eta A numeric value between zero and one to control the learning rate.
+#' @param colsample_bytree Subsampling proportion of columns.
+#' @param min_child_weight A numeric value for the minimum sum of instance
+#' weights needed in a child to continue to split.
+#' @param gamma An number for the minimum loss reduction required to make a
+#' further partition on a leaf node of the tree
+#' @param subsample Subsampling proportion of rows.
+#' @param ... Other options to pass to `xgb.train`.
+#' @return A fitted `xgboost` object.
+#' @export
 xgb_train <- function(
   x, y,
   max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
@@ -380,6 +419,31 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
 
 # C5.0 helpers -----------------------------------------------------------------
 
+#' Boosted trees via C5.0
+#'
+#' `C5.0_train` is a wrapper for [C50::C5.0()] tree-based models
+#' where all of the model arguments are in the main function.
+#'
+#' @param x A data frame or matrix of predictors.
+#' @param y A factor vector with 2 or more levels
+#' @param trials An integer specifying the number of boosting
+#' iterations. A value of one indicates that a single model is
+#' used.
+#' @param weights An optional numeric vector of case weights. Note
+#' that the data used for the case weights will not be used as a
+#' splitting variable in the model (see
+#' \url{http://www.rulequest.com/see5-win.html#CASEWEIGHT} for
+#' Quinlan's notes on case weights).
+#' @param minCases An integer for the smallest number of samples
+#' that must be put in at least two of the splits.
+#' @param sample A value between (0, .999) that specifies the
+#' random proportion of the data should be used to train the model.
+#' By default, all the samples are used for model training. Samples
+#' not used for training are used to evaluate the accuracy of the
+#' model in the printed output.
+#' @param ... Other arguments to pass.
+#' @return A fitted C5.0 model.
+#' @export
 C5.0_train <-
   function(x, y, weights = NULL, trials = 15, minCases = 2, sample = 0, ...) {
     other_args <- list(...)