@@ -290,11 +290,8 @@ xgb_train <- function(
290290 min_child_weight = 1 , gamma = 0 , subsample = 1 , validation = 0 ,
291291 early_stop = NULL , ... ) {
292292
293- if (length(levels(y )) > 2 ) {
294- num_class <- length(levels(y ))
295- } else {
296- num_class <- NULL
297- }
293+ num_class <- length(levels(y ))
294+
298295 if (! is.numeric(validation ) || validation < 0 || validation > = 1 ) {
299296 rlang :: abort(" `validation` should be on [0, 1)." )
300297 }
@@ -311,36 +308,17 @@ xgb_train <- function(
311308 if (is.numeric(y )) {
312309 loss <- " reg:squarederror"
313310 } else {
314- lvl <- levels(y )
315- y <- as.numeric(y ) - 1
316- if (length(lvl ) == 2 ) {
311+ if (num_class == 2 ) {
317312 loss <- " binary:logistic"
318313 } else {
319314 loss <- " multi:softprob"
320315 }
321316 }
322317
323- if (is.data.frame(x )) {
324- x <- as.matrix(x ) # maybe use model.matrix here?
325- }
326-
327318 n <- nrow(x )
328319 p <- ncol(x )
329320
330- if (! inherits(x , " xgb.DMatrix" )) {
331- if (validation > 0 ) {
332- trn_index <- sample(1 : n , size = floor(n * validation ) + 1 )
333- wlist <-
334- list (validation = xgboost :: xgb.DMatrix(x [- trn_index , ], label = y [- trn_index ], missing = NA ))
335- x <- xgboost :: xgb.DMatrix(x [trn_index , ], label = y [trn_index ], missing = NA )
336-
337- } else {
338- x <- xgboost :: xgb.DMatrix(x , label = y , missing = NA )
339- wlist <- list (training = x )
340- }
341- } else {
342- xgboost :: setinfo(x , " label" , y )
343- }
321+ x <- as_xgb_data(x , y , validation )
344322
345323 # translate `subsample` and `colsample_bytree` to be on (0, 1] if not
346324 if (subsample > 1 ) {
@@ -366,17 +344,15 @@ xgb_train <- function(
366344 subsample = subsample
367345 )
368346
369- # eval if contains expressions?
370-
371347 main_args <- list (
372- data = quote(x ),
373- watchlist = quote(wlist ),
348+ data = quote(x $ data ),
349+ watchlist = quote(x $ watchlist ),
374350 params = arg_list ,
375351 nrounds = nrounds ,
376352 objective = loss ,
377353 early_stopping_rounds = early_stop
378354 )
379- if (! is.null(num_class )) {
355+ if (! is.null(num_class ) && num_class > 2 ) {
380356 main_args $ num_class <- num_class
381357 }
382358
@@ -399,7 +375,7 @@ xgb_train <- function(
399375# ' @importFrom stats binomial
400376xgb_pred <- function (object , newdata , ... ) {
401377 if (! inherits(newdata , " xgb.DMatrix" )) {
402- newdata <- as.matrix (newdata )
378+ newdata <- maybe_matrix (newdata )
403379 newdata <- xgboost :: xgb.DMatrix(data = newdata , missing = NA )
404380 }
405381
@@ -415,6 +391,37 @@ xgb_pred <- function(object, newdata, ...) {
415391 x
416392}
417393
394+
# Convert predictors and an outcome into the objects handed to xgboost's
# training routine.
#
# Arguments:
#   x          A data frame, matrix, or `xgb.DMatrix` of predictors. Data
#              frames are converted with `as.matrix()`.
#   y          The outcome. A factor is recoded to zero-based numeric codes
#              (first level -> 0); any other input is passed on unchanged.
#   validation Proportion of rows to hold out as a validation set. Only used
#              when `x` is not already an `xgb.DMatrix`; a value of 0 means
#              no rows are held out.
#   ...        Not used by this function.
#
# Returns a list with two elements:
#   data       The `xgb.DMatrix` to train on.
#   watchlist  A named list holding either the held-out `validation` set or,
#              when nothing is held out, the `training` data itself.
as_xgb_data <- function(x, y, validation = 0, ...) {
  n <- nrow(x)

  if (is.data.frame(x)) {
    x <- as.matrix(x)
  }

  if (is.factor(y)) {
    y <- as.numeric(y) - 1
  }

  if (!inherits(x, "xgb.DMatrix")) {
    if (validation > 0) {
      trn_index <- sample(seq_len(n), size = floor(n * (1 - validation)) + 1)
      # `drop = FALSE` keeps the subsets as matrices even when there is only
      # one predictor column or a single row ends up in one of the splits.
      wlist <- list(
        validation = xgboost::xgb.DMatrix(
          x[-trn_index, , drop = FALSE],
          label = y[-trn_index],
          missing = NA
        )
      )
      dat <- xgboost::xgb.DMatrix(
        x[trn_index, , drop = FALSE],
        label = y[trn_index],
        missing = NA
      )
    } else {
      dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
      wlist <- list(training = dat)
    }
  } else {
    # `setinfo()` updates the DMatrix in place. Its return value is not relied
    # on because it is not the DMatrix in every xgboost release.
    xgboost::setinfo(x, "label", y)
    dat <- x
    wlist <- list(training = dat)
  }

  list(data = dat, watchlist = wlist)
}
418425# ' @importFrom purrr map_df
419426# ' @export
420427# ' @rdname multi_predict
0 commit comments