1919# ' \item \code{loss_reduction}: The reduction in the loss function required
2020# ' to split further.
2121# ' \item \code{sample_size}: The amount of data exposed to the fitting routine.
22+ # ' \item \code{stop_iter}: The number of iterations without improvement before
23+ # ' stopping.
2224# ' }
2325# ' These arguments are converted to their specific names at the
2426# ' time that the model is fit. Other options and argument can be
4648# ' @param sample_size A number for the number (or proportion) of data that is
4749# ' exposed to the fitting routine. For `xgboost`, the sampling is done at at
4850# ' each iteration while `C5.0` samples once during training.
51+ # ' @param stop_iter The number of iterations without improvement before
52+ # ' stopping (`xgboost` only).
4953# ' @details
5054# ' The data given to the function are not saved and are only used
5155# ' to determine the _mode_ of the model. For `boost_tree()`, the
@@ -87,15 +91,17 @@ boost_tree <-
8791 mtry = NULL , trees = NULL , min_n = NULL ,
8892 tree_depth = NULL , learn_rate = NULL ,
8993 loss_reduction = NULL ,
90- sample_size = NULL ) {
94+ sample_size = NULL ,
95+ stop_iter = NULL ) {
9196 args <- list (
9297 mtry = enquo(mtry ),
9398 trees = enquo(trees ),
9499 min_n = enquo(min_n ),
95100 tree_depth = enquo(tree_depth ),
96101 learn_rate = enquo(learn_rate ),
97102 loss_reduction = enquo(loss_reduction ),
98- sample_size = enquo(sample_size )
103+ sample_size = enquo(sample_size ),
104+ stop_iter = enquo(stop_iter )
99105 )
100106
101107 new_model_spec(
@@ -155,6 +161,7 @@ update.boost_tree <-
155161 mtry = NULL , trees = NULL , min_n = NULL ,
156162 tree_depth = NULL , learn_rate = NULL ,
157163 loss_reduction = NULL , sample_size = NULL ,
164+ stop_iter = NULL ,
158165 fresh = FALSE , ... ) {
159166 update_dot_check(... )
160167
@@ -169,7 +176,8 @@ update.boost_tree <-
169176 tree_depth = enquo(tree_depth ),
170177 learn_rate = enquo(learn_rate ),
171178 loss_reduction = enquo(loss_reduction ),
172- sample_size = enquo(sample_size )
179+ sample_size = enquo(sample_size ),
180+ stop_iter = enquo(stop_iter )
173181 )
174182
175183 args <- update_main_parameters(args , parameters )
@@ -242,8 +250,8 @@ check_args.boost_tree <- function(object) {
242250
243251# ' Boosted trees via xgboost
244252# '
245- # ' `xgb_train` is a wrapper for `xgboost` tree-based models
246- # ' where all of the model arguments are in the main function.
253+ # ' `xgb_train` is a wrapper for `xgboost` tree-based models where all of the
254+ # ' model arguments are in the main function.
247255# '
248256# ' @param x A data frame or matrix of predictors
249257# ' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
@@ -256,16 +264,41 @@ check_args.boost_tree <- function(object) {
256264# ' @param gamma A number for the minimum loss reduction required to make a
257265# ' further partition on a leaf node of the tree
258266# ' @param subsample Subsampling proportion of rows.
267+ # ' @param validation A positive number. If on `[0, 1)` the value, `validation`
268+ # ' is a random proportion of data in `x` and `y` that are used for performance
269+ # ' assessment and potential early stopping. If 1 or greater, it is the _number_
270+ # ' of training set samples use for these purposes.
271+ # ' @param early_stop An integer or `NULL`. If not `NULL`, it is the number of
272+ # ' training iterations without improvement before stopping. If `validation` is
273+ # ' used, performance is base on the validation set; otherwise the training set
274+ # ' is used.
259275# ' @param ... Other options to pass to `xgb.train`.
260276# ' @return A fitted `xgboost` object.
261277# ' @keywords internal
262278# ' @export
263279xgb_train <- function (
264280 x , y ,
265281 max_depth = 6 , nrounds = 15 , eta = 0.3 , colsample_bytree = 1 ,
266- min_child_weight = 1 , gamma = 0 , subsample = 1 , ... ) {
282+ min_child_weight = 1 , gamma = 0 , subsample = 1 , validation = 0 ,
283+ early_stop = NULL , ... ) {
284+
285+ if (length(levels(y )) > 2 ) {
286+ num_class <- length(levels(y ))
287+ } else {
288+ num_class <- NULL
289+ }
290+ if (! is.numeric(validation ) || validation < 0 || validation > = 1 ) {
291+ rlang :: abort(" `validation` should be on [0, 1)." )
292+ }
293+ if (! is.null(early_stop )) {
294+ if (early_stop < = 1 ) {
295+ rlang :: abort(paste0(" `early_stop` should be on [2, " , nrounds , " )." ))
296+ } else if (early_stop > = nrounds ) {
297+ early_stop <- nrounds - 1
298+ rlang :: warn(paste0(" `early_stop` was reduced to " , early_stop , " ." ))
299+ }
300+ }
267301
268- num_class <- if (length(levels(y )) > 2 ) length(levels(y )) else NULL
269302
270303 if (is.numeric(y )) {
271304 loss <- " reg:linear"
@@ -287,7 +320,16 @@ xgb_train <- function(
287320 p <- ncol(x )
288321
289322 if (! inherits(x , " xgb.DMatrix" )) {
290- x <- xgboost :: xgb.DMatrix(x , label = y , missing = NA )
323+ if (validation > 0 ) {
324+ trn_index <- sample(1 : n , size = floor(n * validation ) + 1 )
325+ wlist <-
326+ list (validation = xgboost :: xgb.DMatrix(x [- trn_index , ], label = y [- trn_index ], missing = NA ))
327+ x <- xgboost :: xgb.DMatrix(x [trn_index , ], label = y [trn_index ], missing = NA )
328+
329+ } else {
330+ x <- xgboost :: xgb.DMatrix(x , label = y , missing = NA )
331+ wlist <- list (training = x )
332+ }
291333 } else {
292334 xgboost :: setinfo(x , " label" , y )
293335 }
@@ -320,9 +362,11 @@ xgb_train <- function(
320362
321363 main_args <- list (
322364 data = quote(x ),
365+ watchlist = quote(wlist ),
323366 params = arg_list ,
324367 nrounds = nrounds ,
325- objective = loss
368+ objective = loss ,
369+ early_stopping_rounds = early_stop
326370 )
327371 if (! is.null(num_class )) {
328372 main_args $ num_class <- num_class
@@ -334,6 +378,9 @@ xgb_train <- function(
334378 others <- list (... )
335379 others <-
336380 others [! (names(others ) %in% c(" data" , " weights" , " nrounds" , " num_class" , names(arg_list )))]
381+ if (! (any(names(others ) == " verbose" ))) {
382+ others $ verbose <- 0
383+ }
337384 if (length(others ) > 0 ) {
338385 call <- rlang :: call_modify(call , !!! others )
339386 }
0 commit comments