@@ -221,6 +221,10 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
221221 }
222222 x <- translate.default(x , engine , ... )
223223
224+ # # -----------------------------------------------------------------------------
225+
226+ arg_vals <- x $ method $ fit $ args
227+
224228 if (engine == " spark" ) {
225229 if (x $ mode == " unknown" ) {
226230 rlang :: abort(
@@ -230,9 +234,23 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
230234 )
231235 )
232236 } else {
233- x $ method $ fit $ args $ type <- x $ mode
237+ arg_vals $ type <- x $ mode
234238 }
235239 }
240+
241+ # # -----------------------------------------------------------------------------
242+ # Protect some arguments based on data dimensions
243+
244+ # min_n parameters
245+ if (any(names(arg_vals ) == " min_instances_per_node" )) {
246+ arg_vals $ min_instances_per_node <-
247+ rlang :: call2(" min_rows" , rlang :: eval_tidy(arg_vals $ min_instances_per_node ), expr(x ))
248+ }
249+
250+ # # -----------------------------------------------------------------------------
251+
252+ x $ method $ fit $ args <- arg_vals
253+
236254 x
237255}
238256
@@ -242,14 +260,18 @@ check_args.boost_tree <- function(object) {
242260
243261 args <- lapply(object $ args , rlang :: eval_tidy )
244262
245- if (is.numeric(args $ trees ) && args $ trees < 0 )
263+ if (is.numeric(args $ trees ) && args $ trees < 0 ) {
246264 rlang :: abort(" `trees` should be >= 1." )
247- if (is.numeric(args $ sample_size ) && (args $ sample_size < 0 | args $ sample_size > 1 ))
265+ }
266+ if (is.numeric(args $ sample_size ) && (args $ sample_size < 0 | args $ sample_size > 1 )) {
248267 rlang :: abort(" `sample_size` should be within [0,1]." )
249- if (is.numeric(args $ tree_depth ) && args $ tree_depth < 0 )
268+ }
269+ if (is.numeric(args $ tree_depth ) && args $ tree_depth < 0 ) {
250270 rlang :: abort(" `tree_depth` should be >= 1." )
251- if (is.numeric(args $ min_n ) && args $ min_n < 0 )
271+ }
272+ if (is.numeric(args $ min_n ) && args $ min_n < 0 ) {
252273 rlang :: abort(" `min_n` should be >= 1." )
274+ }
253275
254276 invisible (object )
255277}
@@ -335,12 +357,19 @@ xgb_train <- function(
335357 colsample_bytree <- 1
336358 }
337359
360+ if (min_child_weight > n ) {
361+ msg <- paste0(min_child_weight , " samples were requested but there were " ,
362+ n , " rows in the data. " , n , " will be used." )
363+ rlang :: warn(msg )
364+ min_child_weight <- min(min_child_weight , n )
365+ }
366+
338367 arg_list <- list (
339368 eta = eta ,
340369 max_depth = max_depth ,
341370 gamma = gamma ,
342371 colsample_bytree = colsample_bytree ,
343- min_child_weight = min_child_weight ,
372+ min_child_weight = min( min_child_weight , n ) ,
344373 subsample = subsample
345374 )
346375
@@ -515,8 +544,21 @@ C5.0_train <-
515544 ctrl_args <- other_args [names(other_args ) %in% c_names ]
516545 fit_args <- other_args [names(other_args ) %in% f_names ]
517546
547+ n <- nrow(x )
548+ if (n == 0 ) {
549+ rlang :: abort(" There are zero rows in the predictor set." )
550+ }
551+
552+
518553 ctrl <- call2(" C5.0Control" , .ns = " C50" )
554+ if (minCases > n ) {
555+ msg <- paste0(minCases , " samples were requested but there were " ,
556+ n , " rows in the data. " , n , " will be used." )
557+ rlang :: warn(msg )
558+ minCases <- n
559+ }
519560 ctrl $ minCases <- minCases
561+
520562 ctrl $ sample <- sample
521563 ctrl <- rlang :: call_modify(ctrl , !!! ctrl_args )
522564
0 commit comments