@@ -137,7 +137,7 @@ print.boost_tree <- function(x, ...) {
137137 cat(" Boosted Tree Model Specification (" , x $ mode , " )\n\n " , sep = " " )
138138 model_printer(x , ... )
139139
140- if (! is.null(x $ method $ fit $ args )) {
140+ if (! is.null(x $ method $ fit $ args )) {
141141 cat(" Model fit template:\n " )
142142 print(show_call(x ))
143143 }
@@ -211,14 +211,15 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
211211 x <- translate.default(x , engine , ... )
212212
213213 if (engine == " spark" ) {
214- if (x $ mode == " unknown" )
214+ if (x $ mode == " unknown" ) {
215215 stop(
216216 " For spark boosted trees models, the mode cannot be 'unknown' " ,
217217 " if the specification is to be translated." ,
218218 call. = FALSE
219219 )
220- else
220+ } else {
221221 x $ method $ fit $ args $ type <- x $ mode
222+ }
222223 }
223224 x
224225}
@@ -282,27 +283,33 @@ xgb_train <- function(
282283 }
283284 }
284285
285- if (is.data.frame(x ))
286+ if (is.data.frame(x )) {
286287 x <- as.matrix(x ) # maybe use model.matrix here?
288+ }
287289
288290 n <- nrow(x )
289291 p <- ncol(x )
290292
291- if (! inherits(x , " xgb.DMatrix" ))
293+ if (! inherits(x , " xgb.DMatrix" )) {
292294 x <- xgboost :: xgb.DMatrix(x , label = y , missing = NA )
293- else
295+ } else {
294296 xgboost :: setinfo(x , " label" , y )
297+ }
295298
296299 # translate `subsample` and `colsample_bytree` to be on (0, 1] if not
297- if (subsample > 1 )
300+ if (subsample > 1 ) {
298301 subsample <- subsample / n
299- if (subsample > 1 )
302+ }
303+ if (subsample > 1 ) {
300304 subsample <- 1
305+ }
301306
302- if (colsample_bytree > 1 )
307+ if (colsample_bytree > 1 ) {
303308 colsample_bytree <- colsample_bytree / p
304- if (colsample_bytree > 1 )
309+ }
310+ if (colsample_bytree > 1 ) {
305311 colsample_bytree <- 1
312+ }
306313
307314 arg_list <- list (
308315 eta = eta ,
@@ -321,18 +328,19 @@ xgb_train <- function(
321328 nrounds = nrounds ,
322329 objective = loss
323330 )
324- if (! is.null(num_class ))
331+ if (! is.null(num_class )) {
325332 main_args $ num_class <- num_class
333+ }
326334
327335 call <- make_call(fun = " xgb.train" , ns = " xgboost" , main_args )
328336
329337 # override or add some other args
330338 others <- list (... )
331339 others <-
332340 others [! (names(others ) %in% c(" data" , " weights" , " nrounds" , " num_class" , names(arg_list )))]
333- if (length(others ) > 0 )
334- for ( i in names( others ) )
335- call [[ i ]] <- others [[ i ]]
341+ if (length(others ) > 0 ) {
342+ call <- rlang :: call_modify( call , !!! others )
343+ }
336344
337345 eval_tidy(call , env = current_env())
338346}
@@ -348,7 +356,7 @@ xgb_pred <- function(object, newdata, ...) {
348356
349357 x = switch (
350358 object $ params $ objective ,
351- " reg:linear" = , " reg:logistic" = , " binary:logistic" = res ,
359+ " reg:linear" = , " reg:logistic" = , " binary:logistic" = res ,
352360 " binary:logitraw" = stats :: binomial()$ linkinv(res ),
353361 " multi:softprob" = matrix (res , ncol = object $ params $ num_class , byrow = TRUE ),
354362 res
@@ -360,11 +368,13 @@ xgb_pred <- function(object, newdata, ...) {
360368# ' @export
361369multi_predict._xgb.Booster <-
362370 function (object , new_data , type = NULL , trees = NULL , ... ) {
363- if (any(names(enquos(... )) == " newdata" ))
371+ if (any(names(enquos(... )) == " newdata" )) {
364372 stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
373+ }
365374
366- if (is.null(trees ))
375+ if (is.null(trees )) {
367376 trees <- object $ fit $ nIter
377+ }
368378 trees <- sort(trees )
369379
370380 if (is.null(type )) {
@@ -374,9 +384,8 @@ multi_predict._xgb.Booster <-
374384 type <- " numeric"
375385 }
376386
377- res <-
378- map_df(trees , xgb_by_tree , object = object ,
379- new_data = new_data , type = type , ... )
387+ res <- map_df(trees , xgb_by_tree , object = object , new_data = new_data ,
388+ type = type , ... )
380389 res <- arrange(res , .row , trees )
381390 res <- split(res [, - 1 ], res $ .row )
382391 names(res ) <- NULL
@@ -449,19 +458,18 @@ C5.0_train <-
449458 ctrl <- call2(" C5.0Control" , .ns = " C50" )
450459 ctrl $ minCases <- minCases
451460 ctrl $ sample <- sample
452- for (i in names(ctrl_args ))
453- ctrl [[i ]] <- ctrl_args [[i ]]
461+ ctrl <- rlang :: call_modify(ctrl , !!! ctrl_args )
454462
455463 fit_call <- call2(" C5.0" , .ns = " C50" )
456464 fit_call $ x <- expr(x )
457465 fit_call $ y <- expr(y )
458466 fit_call $ trials <- trials
459467 fit_call $ control <- ctrl
460- if (! is.null(weights ))
468+ if (! is.null(weights )) {
461469 fit_call $ weights <- quote(weights )
470+ }
471+ fit_call <- rlang :: call_modify(fit_call , !!! fit_args )
462472
463- for (i in names(fit_args ))
464- fit_call [[i ]] <- fit_args [[i ]]
465473 eval_tidy(fit_call )
466474 }
467475
0 commit comments