@@ -161,6 +161,8 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
161161
162162 x <- translate.default(x , engine , ... )
163163
164+ # # -----------------------------------------------------------------------------
165+
164166 # slightly cleaner code using
165167 arg_vals <- x $ method $ fit $ args
166168
@@ -185,14 +187,40 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
185187
186188 # add checks to error trap or change things for this method
187189 if (engine == " ranger" ) {
188- if (any(names(arg_vals ) == " importance" ))
189- if (isTRUE(is.logical(quo_get_expr(arg_vals $ importance ))))
190+
191+ if (any(names(arg_vals ) == " importance" )) {
192+ if (isTRUE(is.logical(quo_get_expr(arg_vals $ importance )))) {
190193 rlang :: abort(" `importance` should be a character value. See ?ranger::ranger." )
194+ }
195+ }
191196 # unless otherwise specified, classification models are probability forests
192- if (x $ mode == " classification" && ! any(names(arg_vals ) == " probability" ))
197+ if (x $ mode == " classification" && ! any(names(arg_vals ) == " probability" )) {
193198 arg_vals $ probability <- TRUE
199+ }
200+ }
201+
202+ # # -----------------------------------------------------------------------------
203+ # Protect some arguments based on data dimensions
194204
205+ if (any(names(arg_vals ) == " mtry" )) {
206+ arg_vals $ mtry <- rlang :: call2(" min" , arg_vals $ mtry , expr(ncol(x )))
195207 }
208+
209+ if (any(names(arg_vals ) == " min.node.size" )) {
210+ arg_vals $ min.node.size <-
211+ rlang :: call2(" min" , arg_vals $ min.node.size , expr(nrow(x )))
212+ }
213+ if (any(names(arg_vals ) == " nodesize" )) {
214+ arg_vals $ nodesize <-
215+ rlang :: call2(" min" , arg_vals $ nodesize , expr(nrow(x )))
216+ }
217+ if (any(names(arg_vals ) == " min_instances_per_node" )) {
218+ arg_vals $ min_instances_per_node <-
219+ rlang :: call2(" min" , arg_vals $ min_instances_per_node , expr(nrow(x )))
220+ }
221+
222+ # # -----------------------------------------------------------------------------
223+
196224 x $ method $ fit $ args <- arg_vals
197225
198226 x
0 commit comments