Skip to content

Commit 228a6dc

Browse files
authored
Merge pull request #466 from tidymodels/pred-event-level
Make predictions respect `event_level` for xgboost
2 parents b79ff6d + 9521709 commit 228a6dc

File tree

4 files changed

+78
-34
lines changed

4 files changed

+78
-34
lines changed

R/boost_tree.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,17 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
472472

473473
list(data = dat, watchlist = wlist)
474474
}
475+
476+
get_event_level <- function(model_spec){
477+
if ("event_level" %in% names(model_spec$eng_args)) {
478+
event_level <- get_expr(model_spec$eng_args$event_level)
479+
} else {
480+
# "first" is the default for as_xgb_data() and xgb_train()
481+
event_level <- "first"
482+
}
483+
event_level
484+
}
485+
475486
#' @importFrom purrr map_df
476487
#' @export
477488
#' @rdname multi_predict

R/boost_tree_data.R

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,12 @@ set_pred(
158158
pre = NULL,
159159
post = function(x, object) {
160160
if (is.vector(x)) {
161-
x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1])
161+
event_level <- get_event_level(object$spec)
162+
if (event_level == "first") {
163+
x <- ifelse(x >= 0.5, object$lvl[1], object$lvl[2])
164+
} else {
165+
x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1])
166+
}
162167
} else {
163168
x <- object$lvl[apply(x, 1, which.max)]
164169
}
@@ -178,7 +183,12 @@ set_pred(
178183
pre = NULL,
179184
post = function(x, object) {
180185
if (is.vector(x)) {
181-
x <- tibble(v1 = 1 - x, v2 = x)
186+
event_level <- get_event_level(object$spec)
187+
if (event_level == "first") {
188+
x <- tibble(v1 = x, v2 = 1 - x)
189+
} else {
190+
x <- tibble(v1 = 1 - x, v2 = x)
191+
}
182192
} else {
183193
x <- as_tibble(x, .name_repair = "minimal")
184194
}

man/boost_tree.Rd

Lines changed: 9 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,14 @@ test_that('submodel prediction', {
210210

211211
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 5, type = "prob")
212212
mp_res <- do.call("rbind", mp_res$.pred)
213-
expect_equal(mp_res[[".pred_No"]], pred_class)
213+
expect_equal(mp_res[[".pred_Yes"]], pred_class)
214214

215215
expect_error(
216216
multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 5, type = "prob"),
217217
"Did you mean"
218218
)
219219
})
220220

221-
222221
test_that('default engine', {
223222
skip_if_not_installed("xgboost")
224223
expect_warning(
@@ -422,43 +421,64 @@ test_that('argument checks for data dimensions', {
422421

423422
})
424423

425-
test_that("set `event_level` as engine-specific argument", {
424+
test_that("fit and prediction with `event_level`", {
426425

427426
skip_if_not_installed("xgboost")
428427

429428
data(penguins, package = "modeldata")
430429
penguins <- na.omit(penguins[, -c(1:2)])
431430

432-
spec <-
433-
boost_tree(trees = 10, tree_depth = 3) %>%
434-
set_engine(
435-
"xgboost",
436-
eval_metric = "aucpr",
437-
event_level = "second",
438-
verbose = 1
439-
) %>%
440-
set_mode("classification")
431+
train_x <- as.matrix(penguins[-(1:4), -5])
432+
train_y_1 <- -as.numeric(penguins$sex[-(1:4)]) + 2
433+
train_y_2 <- as.numeric(penguins$sex[-(1:4)]) - 1
441434

435+
x_pred <- xgboost::xgb.DMatrix(as.matrix(penguins[1:4, -5]))
436+
437+
# event_level = "first"
442438
set.seed(24)
443-
fit_p <- spec %>% fit(sex ~ ., data = penguins)
439+
fit_p_1 <- boost_tree(trees = 10) %>%
440+
set_engine("xgboost", eval_metric = "auc"
441+
# event_level = "first" is the default
442+
) %>%
443+
set_mode("classification") %>%
444+
fit(sex ~ ., data = penguins[-(1:4), ])
444445

445-
penguins_x <- as.matrix(penguins[, -5])
446-
penguins_y <- as.numeric(penguins$sex) - 1
447-
xgbmat <- xgb.DMatrix(data = penguins_x, label = penguins_y)
446+
xgbmat_train_1 <- xgb.DMatrix(data = train_x, label = train_y_1)
448447

449448
set.seed(24)
450-
fit_xgb <- xgboost::xgb.train(data = xgbmat,
451-
params = list(eta = 0.3, max_depth = 3,
452-
gamma = 0, colsample_bytree = 1,
453-
min_child_weight = 1,
454-
subsample = 1),
449+
fit_xgb_1 <- xgboost::xgb.train(data = xgbmat_train_1,
455450
nrounds = 10,
456-
watchlist = list("training" = xgbmat),
451+
watchlist = list("training" = xgbmat_train_1),
457452
objective = "binary:logistic",
458-
verbose = 1,
459-
eval_metric = "aucpr",
460-
nthread = 1)
453+
eval_metric = "auc")
454+
455+
expect_equal(fit_p_1$fit$evaluation_log, fit_xgb_1$evaluation_log)
456+
457+
pred_xgb_1 <- predict(fit_xgb_1, x_pred)
458+
pred_p_1 <- predict(fit_p_1, new_data = penguins[1:4, ], type = "prob")
459+
expect_equal(pred_p_1[[".pred_female"]], pred_xgb_1)
460+
461+
# event_level = "second"
462+
set.seed(24)
463+
fit_p_2 <- boost_tree(trees = 10) %>%
464+
set_engine("xgboost", eval_metric = "auc",
465+
event_level = "second") %>%
466+
set_mode("classification") %>%
467+
fit(sex ~ ., data = penguins[-(1:4), ])
468+
469+
xgbmat_train_2 <- xgb.DMatrix(data = train_x, label = train_y_2)
470+
471+
set.seed(24)
472+
fit_xgb_2 <- xgboost::xgb.train(data = xgbmat_train_2,
473+
nrounds = 10,
474+
watchlist = list("training" = xgbmat_train_2),
475+
objective = "binary:logistic",
476+
eval_metric = "auc")
477+
478+
expect_equal(fit_p_2$fit$evaluation_log, fit_xgb_2$evaluation_log)
461479

462-
expect_equal(fit_p$fit$evaluation_log, fit_xgb$evaluation_log)
480+
pred_xgb_2 <- predict(fit_xgb_2, x_pred)
481+
pred_p_2 <- predict(fit_p_2, new_data = penguins[1:4, ], type = "prob")
482+
expect_equal(pred_p_2[[".pred_male"]], pred_xgb_2)
463483

464484
})

0 commit comments

Comments
 (0)