Skip to content

Commit 9521709

Browse files
committed
test for fit and prediction
1 parent dc5f046 commit 9521709

File tree

1 file changed

+45
-56
lines changed

1 file changed

+45
-56
lines changed

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 45 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -218,38 +218,6 @@ test_that('submodel prediction', {
218218
)
219219
})
220220

221-
test_that('prediction with event_level', {
222-
223-
skip_if_not_installed("xgboost")
224-
library(xgboost)
225-
226-
vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges")
227-
228-
# event_level = "first"
229-
fit_1 <-
230-
boost_tree(trees = 20, mode = "classification") %>%
231-
set_engine("xgboost") %>%
232-
fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)])
233-
234-
x <- xgboost::xgb.DMatrix(as.matrix(wa_churn[1:4, vars]))
235-
236-
pred_xgb_1 <- predict(fit_1$fit, x)
237-
pred_res_1 <- predict(fit_1, new_data = wa_churn[1:4, vars], type = "prob")
238-
expect_equal(pred_res_1[[".pred_Yes"]], pred_xgb_1)
239-
240-
# event_level = "second"
241-
fit_2 <-
242-
boost_tree(trees = 20, mode = "classification") %>%
243-
set_engine("xgboost", event_level = "second") %>%
244-
fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)])
245-
246-
x <- xgboost::xgb.DMatrix(as.matrix(wa_churn[1:4, vars]))
247-
248-
pred_xgb_2 <- predict(fit_2$fit, x)
249-
pred_res_2 <- predict(fit_2, new_data = wa_churn[1:4, vars], type = "prob")
250-
expect_equal(pred_res_2[[".pred_No"]], pred_xgb_2)
251-
})
252-
253221
test_that('default engine', {
254222
skip_if_not_installed("xgboost")
255223
expect_warning(
@@ -453,43 +421,64 @@ test_that('argument checks for data dimensions', {
453421

454422
})
455423

456-
test_that("set `event_level` as engine-specific argument", {
424+
test_that("fit and prediction with `event_level`", {
457425

458426
skip_if_not_installed("xgboost")
459427

460428
data(penguins, package = "modeldata")
461429
penguins <- na.omit(penguins[, -c(1:2)])
462430

463-
spec <-
464-
boost_tree(trees = 10, tree_depth = 3) %>%
465-
set_engine(
466-
"xgboost",
467-
eval_metric = "aucpr",
468-
event_level = "second",
469-
verbose = 1
470-
) %>%
471-
set_mode("classification")
431+
train_x <- as.matrix(penguins[-(1:4), -5])
432+
train_y_1 <- -as.numeric(penguins$sex[-(1:4)]) + 2
433+
train_y_2 <- as.numeric(penguins$sex[-(1:4)]) - 1
434+
435+
x_pred <- xgboost::xgb.DMatrix(as.matrix(penguins[1:4, -5]))
472436

437+
# event_level = "first"
473438
set.seed(24)
474-
fit_p <- spec %>% fit(sex ~ ., data = penguins)
439+
fit_p_1 <- boost_tree(trees = 10) %>%
440+
set_engine("xgboost", eval_metric = "auc"
441+
# event_level = "first" is the default
442+
) %>%
443+
set_mode("classification") %>%
444+
fit(sex ~ ., data = penguins[-(1:4), ])
475445

476-
penguins_x <- as.matrix(penguins[, -5])
477-
penguins_y <- as.numeric(penguins$sex) - 1
478-
xgbmat <- xgb.DMatrix(data = penguins_x, label = penguins_y)
446+
xgbmat_train_1 <- xgb.DMatrix(data = train_x, label = train_y_1)
479447

480448
set.seed(24)
481-
fit_xgb <- xgboost::xgb.train(data = xgbmat,
482-
params = list(eta = 0.3, max_depth = 3,
483-
gamma = 0, colsample_bytree = 1,
484-
min_child_weight = 1,
485-
subsample = 1),
449+
fit_xgb_1 <- xgboost::xgb.train(data = xgbmat_train_1,
486450
nrounds = 10,
487-
watchlist = list("training" = xgbmat),
451+
watchlist = list("training" = xgbmat_train_1),
488452
objective = "binary:logistic",
489-
verbose = 1,
490-
eval_metric = "aucpr",
491-
nthread = 1)
453+
eval_metric = "auc")
454+
455+
expect_equal(fit_p_1$fit$evaluation_log, fit_xgb_1$evaluation_log)
456+
457+
pred_xgb_1 <- predict(fit_xgb_1, x_pred)
458+
pred_p_1 <- predict(fit_p_1, new_data = penguins[1:4, ], type = "prob")
459+
expect_equal(pred_p_1[[".pred_female"]], pred_xgb_1)
460+
461+
# event_level = "second"
462+
set.seed(24)
463+
fit_p_2 <- boost_tree(trees = 10) %>%
464+
set_engine("xgboost", eval_metric = "auc",
465+
event_level = "second") %>%
466+
set_mode("classification") %>%
467+
fit(sex ~ ., data = penguins[-(1:4), ])
468+
469+
xgbmat_train_2 <- xgb.DMatrix(data = train_x, label = train_y_2)
470+
471+
set.seed(24)
472+
fit_xgb_2 <- xgboost::xgb.train(data = xgbmat_train_2,
473+
nrounds = 10,
474+
watchlist = list("training" = xgbmat_train_2),
475+
objective = "binary:logistic",
476+
eval_metric = "auc")
477+
478+
expect_equal(fit_p_2$fit$evaluation_log, fit_xgb_2$evaluation_log)
492479

493-
expect_equal(fit_p$fit$evaluation_log, fit_xgb$evaluation_log)
480+
pred_xgb_2 <- predict(fit_xgb_2, x_pred)
481+
pred_p_2 <- predict(fit_p_2, new_data = penguins[1:4, ], type = "prob")
482+
expect_equal(pred_p_2[[".pred_male"]], pred_xgb_2)
494483

495484
})

0 commit comments

Comments
 (0)