@@ -221,6 +221,10 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
221221 }
222222 x <- translate.default(x , engine , ... )
223223
224+ # # -----------------------------------------------------------------------------
225+
226+ arg_vals <- x $ method $ fit $ args
227+
224228 if (engine == " spark" ) {
225229 if (x $ mode == " unknown" ) {
226230 rlang :: abort(
@@ -230,9 +234,21 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
230234 )
231235 )
232236 } else {
233- x $ method $ fit $ args $ type <- x $ mode
237+ arg_vals $ type <- x $ mode
234238 }
235239 }
240+
241+ # # -----------------------------------------------------------------------------
242+ # Protect some arguments based on data dimensions
243+
244+ # min_n parameters
245+ if (any(names(arg_vals ) == " min_instances_per_node" )) {
246+ arg_vals $ min_instances_per_node <-
247+ rlang :: call2(" min" , arg_vals $ min_instances_per_node , expr(nrow(x )))
248+ }
249+
250+ # # -----------------------------------------------------------------------------
251+
236252 x
237253}
238254
@@ -242,14 +258,18 @@ check_args.boost_tree <- function(object) {
242258
243259 args <- lapply(object $ args , rlang :: eval_tidy )
244260
245- if (is.numeric(args $ trees ) && args $ trees < 0 )
261+ if (is.numeric(args $ trees ) && args $ trees < 0 ) {
246262 rlang :: abort(" `trees` should be >= 1." )
247- if (is.numeric(args $ sample_size ) && (args $ sample_size < 0 | args $ sample_size > 1 ))
263+ }
264+ if (is.numeric(args $ sample_size ) && (args $ sample_size < 0 | args $ sample_size > 1 )) {
248265 rlang :: abort(" `sample_size` should be within [0,1]." )
249- if (is.numeric(args $ tree_depth ) && args $ tree_depth < 0 )
266+ }
267+ if (is.numeric(args $ tree_depth ) && args $ tree_depth < 0 ) {
250268 rlang :: abort(" `tree_depth` should be >= 1." )
251- if (is.numeric(args $ min_n ) && args $ min_n < 0 )
269+ }
270+ if (is.numeric(args $ min_n ) && args $ min_n < 0 ) {
252271 rlang :: abort(" `min_n` should be >= 1." )
272+ }
253273
254274 invisible (object )
255275}
@@ -340,7 +360,7 @@ xgb_train <- function(
340360 max_depth = max_depth ,
341361 gamma = gamma ,
342362 colsample_bytree = colsample_bytree ,
343- min_child_weight = min_child_weight ,
363+ min_child_weight = min( min_child_weight , n ) ,
344364 subsample = subsample
345365 )
346366
@@ -516,7 +536,7 @@ C5.0_train <-
516536 fit_args <- other_args [names(other_args ) %in% f_names ]
517537
518538 ctrl <- call2(" C5.0Control" , .ns = " C50" )
519- ctrl $ minCases <- minCases
539+ ctrl $ minCases <- min( minCases , nrow( x ))
520540 ctrl $ sample <- sample
521541 ctrl <- rlang :: call_modify(ctrl , !!! ctrl_args )
522542
0 commit comments