Skip to content

Commit e7f8ff6

Browse files
authored
Merge pull request #304 from tidymodels/early-stop
Early stopping in xgboost
2 parents 46e3d09 + 9b12edb commit e7f8ff6

File tree

83 files changed

+1860
-1806
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+1860
-1806
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Imports:
2626
glue,
2727
magrittr,
2828
stats,
29-
tidyr,
29+
tidyr (>= 1.0.0),
3030
globals,
3131
prettyunits,
3232
vctrs (>= 0.2.0)

NEWS.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# parsnip (development version)
22

3+
## Other Changes
4+
5+
* `tidyr` >= 1.0.0 is now required.
6+
7+
## New Features
8+
9+
* A new main argument was added to `boost_tree()` called `stop_iter` for early stopping. The `xgb_train()` function gained arguments for early stopping and a percentage of data to leave out for a validation set.
10+
311
# parsnip 0.1.1
412

513
## New Features

R/boost_tree.R

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#' \item \code{loss_reduction}: The reduction in the loss function required
2020
#' to split further.
2121
#' \item \code{sample_size}: The amount of data exposed to the fitting routine.
22+
#' \item \code{stop_iter}: The number of iterations without improvement before
23+
#' stopping.
2224
#' }
2325
#' These arguments are converted to their specific names at the
2426
#' time that the model is fit. Other options and arguments can be
@@ -46,6 +48,8 @@
4648
#' @param sample_size A number for the number (or proportion) of data that is
4749
#' exposed to the fitting routine. For `xgboost`, the sampling is done at
4850
#' each iteration while `C5.0` samples once during training.
51+
#' @param stop_iter The number of iterations without improvement before
52+
#' stopping (`xgboost` only).
4953
#' @details
5054
#' The data given to the function are not saved and are only used
5155
#' to determine the _mode_ of the model. For `boost_tree()`, the
@@ -87,15 +91,17 @@ boost_tree <-
8791
mtry = NULL, trees = NULL, min_n = NULL,
8892
tree_depth = NULL, learn_rate = NULL,
8993
loss_reduction = NULL,
90-
sample_size = NULL) {
94+
sample_size = NULL,
95+
stop_iter = NULL) {
9196
args <- list(
9297
mtry = enquo(mtry),
9398
trees = enquo(trees),
9499
min_n = enquo(min_n),
95100
tree_depth = enquo(tree_depth),
96101
learn_rate = enquo(learn_rate),
97102
loss_reduction = enquo(loss_reduction),
98-
sample_size = enquo(sample_size)
103+
sample_size = enquo(sample_size),
104+
stop_iter = enquo(stop_iter)
99105
)
100106

101107
new_model_spec(
@@ -155,6 +161,7 @@ update.boost_tree <-
155161
mtry = NULL, trees = NULL, min_n = NULL,
156162
tree_depth = NULL, learn_rate = NULL,
157163
loss_reduction = NULL, sample_size = NULL,
164+
stop_iter = NULL,
158165
fresh = FALSE, ...) {
159166
update_dot_check(...)
160167

@@ -169,7 +176,8 @@ update.boost_tree <-
169176
tree_depth = enquo(tree_depth),
170177
learn_rate = enquo(learn_rate),
171178
loss_reduction = enquo(loss_reduction),
172-
sample_size = enquo(sample_size)
179+
sample_size = enquo(sample_size),
180+
stop_iter = enquo(stop_iter)
173181
)
174182

175183
args <- update_main_parameters(args, parameters)
@@ -242,8 +250,8 @@ check_args.boost_tree <- function(object) {
242250

243251
#' Boosted trees via xgboost
244252
#'
245-
#' `xgb_train` is a wrapper for `xgboost` tree-based models
246-
#' where all of the model arguments are in the main function.
253+
#' `xgb_train` is a wrapper for `xgboost` tree-based models where all of the
254+
#' model arguments are in the main function.
247255
#'
248256
#' @param x A data frame or matrix of predictors
249257
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
@@ -256,16 +264,41 @@ check_args.boost_tree <- function(object) {
256264
#' @param gamma A number for the minimum loss reduction required to make a
257265
#' further partition on a leaf node of the tree
258266
#' @param subsample Subsampling proportion of rows.
267+
#' @param validation A positive number. If on `[0, 1)`, the value of `validation`
268+
#' is a random proportion of data in `x` and `y` that are used for performance
269+
#' assessment and potential early stopping. If 1 or greater, it is the _number_
270+
#' of training set samples used for these purposes.
271+
#' @param early_stop An integer or `NULL`. If not `NULL`, it is the number of
272+
#' training iterations without improvement before stopping. If `validation` is
273+
#' used, performance is based on the validation set; otherwise the training set
274+
#' is used.
259275
#' @param ... Other options to pass to `xgb.train`.
260276
#' @return A fitted `xgboost` object.
261277
#' @keywords internal
262278
#' @export
263279
xgb_train <- function(
264280
x, y,
265281
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
266-
min_child_weight = 1, gamma = 0, subsample = 1, ...) {
282+
min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
283+
early_stop = NULL, ...) {
284+
285+
if (length(levels(y)) > 2) {
286+
num_class <- length(levels(y))
287+
} else {
288+
num_class <- NULL
289+
}
290+
if (!is.numeric(validation) || validation < 0 || validation >= 1) {
291+
rlang::abort("`validation` should be on [0, 1).")
292+
}
293+
if (!is.null(early_stop)) {
294+
if (early_stop <= 1) {
295+
rlang::abort(paste0("`early_stop` should be on [2, ", nrounds, ")."))
296+
} else if (early_stop >= nrounds) {
297+
early_stop <- nrounds - 1
298+
rlang::warn(paste0("`early_stop` was reduced to ", early_stop, "."))
299+
}
300+
}
267301

268-
num_class <- if (length(levels(y)) > 2) length(levels(y)) else NULL
269302

270303
if (is.numeric(y)) {
271304
loss <- "reg:linear"
@@ -287,7 +320,16 @@ xgb_train <- function(
287320
p <- ncol(x)
288321

289322
if (!inherits(x, "xgb.DMatrix")) {
290-
x <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
323+
if (validation > 0) {
324+
trn_index <- sample(1:n, size = floor(n * validation) + 1)
325+
wlist <-
326+
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
327+
x <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
328+
329+
} else {
330+
x <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
331+
wlist <- list(training = x)
332+
}
291333
} else {
292334
xgboost::setinfo(x, "label", y)
293335
}
@@ -320,9 +362,11 @@ xgb_train <- function(
320362

321363
main_args <- list(
322364
data = quote(x),
365+
watchlist = quote(wlist),
323366
params = arg_list,
324367
nrounds = nrounds,
325-
objective = loss
368+
objective = loss,
369+
early_stopping_rounds = early_stop
326370
)
327371
if (!is.null(num_class)) {
328372
main_args$num_class <- num_class
@@ -334,6 +378,9 @@ xgb_train <- function(
334378
others <- list(...)
335379
others <-
336380
others[!(names(others) %in% c("data", "weights", "nrounds", "num_class", names(arg_list)))]
381+
if (!(any(names(others) == "verbose"))) {
382+
others$verbose <- 0
383+
}
337384
if (length(others) > 0) {
338385
call <- rlang::call_modify(call, !!!others)
339386
}

R/boost_tree_data.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ set_model_arg(
6565
func = list(pkg = "dials", fun = "sample_size"),
6666
has_submodel = FALSE
6767
)
68+
set_model_arg(
69+
model = "boost_tree",
70+
eng = "xgboost",
71+
parsnip = "stop_iter",
72+
original = "early_stop",
73+
func = list(pkg = "dials", fun = "stop_iter"),
74+
has_submodel = FALSE
75+
)
76+
6877

6978
set_fit(
7079
model = "boost_tree",

_pkgdown.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
template:
22
package: tidytemplate
33
params:
4-
part_of: <a href="https://github.com/tidymodels">tidymodels</a>
4+
part_of: <a href="https://tidymodels.org">tidymodels</a>
55
footer: <code>parsnip</code> is a part of the <strong>tidymodels</strong> ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.
66

77
# https://github.com/tidyverse/tidytemplate for css

docs/dev/404.html

Lines changed: 9 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/dev/CODE_OF_CONDUCT.html

Lines changed: 17 additions & 29 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)