Skip to content

Commit c9d21e1

Browse files
committed
enable xgboost to use sparse X for tidymodels/tidymodels#42
1 parent 9e30b6b commit c9d21e1

File tree

4 files changed

+112
-33
lines changed

4 files changed

+112
-33
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,5 @@ Suggests:
5151
MASS,
5252
nlme,
5353
modeldata,
54-
liquidSVM
54+
liquidSVM,
55+
Matrix

R/boost_tree.R

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,8 @@ xgb_train <- function(
290290
min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
291291
early_stop = NULL, ...) {
292292

293-
if (length(levels(y)) > 2) {
294-
num_class <- length(levels(y))
295-
} else {
296-
num_class <- NULL
297-
}
293+
num_class <- length(levels(y))
294+
298295
if (!is.numeric(validation) || validation < 0 || validation >= 1) {
299296
rlang::abort("`validation` should be on [0, 1).")
300297
}
@@ -311,36 +308,17 @@ xgb_train <- function(
311308
if (is.numeric(y)) {
312309
loss <- "reg:squarederror"
313310
} else {
314-
lvl <- levels(y)
315-
y <- as.numeric(y) - 1
316-
if (length(lvl) == 2) {
311+
if (num_class == 2) {
317312
loss <- "binary:logistic"
318313
} else {
319314
loss <- "multi:softprob"
320315
}
321316
}
322317

323-
if (is.data.frame(x)) {
324-
x <- as.matrix(x) # maybe use model.matrix here?
325-
}
326-
327318
n <- nrow(x)
328319
p <- ncol(x)
329320

330-
if (!inherits(x, "xgb.DMatrix")) {
331-
if (validation > 0) {
332-
trn_index <- sample(1:n, size = floor(n * validation) + 1)
333-
wlist <-
334-
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
335-
x <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
336-
337-
} else {
338-
x <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
339-
wlist <- list(training = x)
340-
}
341-
} else {
342-
xgboost::setinfo(x, "label", y)
343-
}
321+
x <- as_xgb_data(x, y, validation)
344322

345323
# translate `subsample` and `colsample_bytree` to be on (0, 1] if not
346324
if (subsample > 1) {
@@ -366,17 +344,15 @@ xgb_train <- function(
366344
subsample = subsample
367345
)
368346

369-
# eval if contains expressions?
370-
371347
main_args <- list(
372-
data = quote(x),
373-
watchlist = quote(wlist),
348+
data = quote(x$data),
349+
watchlist = quote(x$watchlist),
374350
params = arg_list,
375351
nrounds = nrounds,
376352
objective = loss,
377353
early_stopping_rounds = early_stop
378354
)
379-
if (!is.null(num_class)) {
355+
if (!is.null(num_class) && num_class > 2) {
380356
main_args$num_class <- num_class
381357
}
382358

@@ -399,7 +375,7 @@ xgb_train <- function(
399375
#' @importFrom stats binomial
400376
xgb_pred <- function(object, newdata, ...) {
401377
if (!inherits(newdata, "xgb.DMatrix")) {
402-
newdata <- as.matrix(newdata)
378+
newdata <- as_matrix(newdata)
403379
newdata <- xgboost::xgb.DMatrix(data = newdata, missing = NA)
404380
}
405381

@@ -415,6 +391,37 @@ xgb_pred <- function(object, newdata, ...) {
415391
x
416392
}
417393

394+
395+
#' Convert predictors and an outcome into xgboost training objects
#'
#' @param x Predictors. A data frame is converted with `as.matrix()`; an
#'   `xgb.DMatrix` is used as-is (only its label is set); anything else is
#'   passed directly to `xgboost::xgb.DMatrix()`.
#' @param y Outcome vector. A factor is recoded to zero-based numeric codes.
#' @param validation A number on `[0, 1)`. When positive and `x` is not already
#'   an `xgb.DMatrix`, `floor(n * (1 - validation)) + 1` randomly chosen rows
#'   form the training set and the remaining rows form the validation set.
#'   It is ignored when `x` is an `xgb.DMatrix`.
#' @param ... Not used.
#' @return A list with elements `data` (the training `xgb.DMatrix`) and
#'   `watchlist` (a named list holding either the `validation` or the
#'   `training` `xgb.DMatrix`).
#' @keywords internal
#' @noRd
as_xgb_data <- function(x, y, validation = 0, ...) {
  n <- nrow(x)

  if (is.data.frame(x)) {
    x <- as.matrix(x)
  }

  # xgboost expects class labels as numbers starting at zero.
  if (is.factor(y)) {
    y <- as.numeric(y) - 1
  }

  if (!inherits(x, "xgb.DMatrix")) {
    if (validation > 0) {
      trn_index <- sample(seq_len(n), size = floor(n * (1 - validation)) + 1)
      # `drop = FALSE` keeps the two-dimensional structure even when a single
      # row or a single predictor column is selected.
      wlist <- list(
        validation = xgboost::xgb.DMatrix(
          x[-trn_index, , drop = FALSE],
          label = y[-trn_index],
          missing = NA
        )
      )
      dat <- xgboost::xgb.DMatrix(
        x[trn_index, , drop = FALSE],
        label = y[trn_index],
        missing = NA
      )
    } else {
      dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
      wlist <- list(training = dat)
    }
  } else {
    # `setinfo()` modifies the DMatrix in place. Its return value is not
    # reliably the DMatrix itself across xgboost releases (older versions
    # appear to return `TRUE`), so keep using `x` rather than the result.
    xgboost::setinfo(x, "label", y)
    dat <- x
    wlist <- list(training = dat)
  }

  list(data = dat, watchlist = wlist)
}
418425
#' @importFrom purrr map_df
419426
#' @export
420427
#' @rdname multi_predict

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,70 @@ test_that('early stopping', {
286286
regex = "`early_stop` should be on"
287287
)
288288
})
289+
290+
291+
## -----------------------------------------------------------------------------
292+
293+
# Checks that `as_xgb_data()` produces `xgb.DMatrix` objects from data frames,
# dense matrices, and sparse matrices, with and without a validation split.
test_that("xgboost data conversion", {
  skip_if_not_installed("xgboost")
  # Matrix is only a suggested dependency, so guard on it as well.
  skip_if_not_installed("Matrix")

  mtcar_x <- mtcars[, -1]
  mtcar_mat <- as.matrix(mtcar_x)
  mtcar_smat <- Matrix::Matrix(mtcar_mat, sparse = TRUE)

  # No validation set: the watchlist holds the training data.
  expect_error(from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg), regexp = NA)
  expect_true(inherits(from_df$data, "xgb.DMatrix"))
  expect_true(inherits(from_df$watchlist$training, "xgb.DMatrix"))

  expect_error(from_mat <- parsnip:::as_xgb_data(mtcar_mat, mtcars$mpg), regexp = NA)
  expect_true(inherits(from_mat$data, "xgb.DMatrix"))
  expect_true(inherits(from_mat$watchlist$training, "xgb.DMatrix"))

  # Assert on the sparse result itself, not on the dense-matrix result.
  expect_error(from_sparse <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg), regexp = NA)
  expect_true(inherits(from_sparse$data, "xgb.DMatrix"))
  expect_true(inherits(from_sparse$watchlist$training, "xgb.DMatrix"))

  # With a validation set: the watchlist holds the held-out rows.
  expect_error(from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg, validation = .1), regexp = NA)
  expect_true(inherits(from_df$data, "xgb.DMatrix"))
  expect_true(inherits(from_df$watchlist$validation, "xgb.DMatrix"))
  expect_true(nrow(from_df$data) > nrow(from_df$watchlist$validation))

  expect_error(from_mat <- parsnip:::as_xgb_data(mtcar_mat, mtcars$mpg, validation = .1), regexp = NA)
  expect_true(inherits(from_mat$data, "xgb.DMatrix"))
  expect_true(inherits(from_mat$watchlist$validation, "xgb.DMatrix"))
  expect_true(nrow(from_mat$data) > nrow(from_mat$watchlist$validation))

  expect_error(from_sparse <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, validation = .1), regexp = NA)
  expect_true(inherits(from_sparse$data, "xgb.DMatrix"))
  expect_true(inherits(from_sparse$watchlist$validation, "xgb.DMatrix"))
  expect_true(nrow(from_sparse$data) > nrow(from_sparse$watchlist$validation))
})
328+
329+
330+
# Fits the same xgboost regression spec to a data frame, a dense matrix, and a
# sparse matrix under an identical seed, and expects the fitted boosters to be
# equal across the three input formats.
test_that("xgboost data and sparse matrices", {
  skip_if_not_installed("xgboost")

  predictors_df <- mtcars[, -1]
  predictors_dense <- as.matrix(predictors_df)
  predictors_sparse <- Matrix::Matrix(predictors_dense, sparse = TRUE)

  spec <- set_mode(
    set_engine(boost_tree(trees = 10), "xgboost"),
    "regression"
  )

  # Reset the seed before every fit so each one starts from the same RNG state.
  set.seed(1)
  fit_df <- fit_xy(spec, predictors_df, mtcars$mpg)
  set.seed(1)
  fit_dense <- fit_xy(spec, predictors_dense, mtcars$mpg)
  set.seed(1)
  fit_sparse <- fit_xy(spec, predictors_sparse, mtcars$mpg)

  expect_equal(fit_df$fit, fit_dense$fit)
  expect_equal(fit_df$fit, fit_sparse$fit)
})
353+
354+
355+

tests/testthat/test_convert_data.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,4 +626,8 @@ test_that("convert to matrix", {
626626
expect_true(inherits(parsnip::as_matrix(mtcars), "matrix"))
627627
expect_true(inherits(parsnip::as_matrix(tibble::as_tibble(mtcars)), "matrix"))
628628
expect_true(inherits(parsnip::as_matrix(as.matrix(mtcars)), "matrix"))
629+
expect_true(
630+
inherits(parsnip::as_matrix(Matrix::Matrix(as.matrix(mtcars), sparse = TRUE)),
631+
"dgCMatrix")
632+
)
629633
})

0 commit comments

Comments
 (0)