Skip to content

Commit a7b31a8

Browse files
committed
add test for prediction with event_level
1 parent 6244c1d commit a7b31a8

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,37 @@ test_that('submodel prediction', {
218218
)
219219
})
220220

221+
test_that('prediction with event_level', {
222+
223+
skip_if_not_installed("xgboost")
224+
library(xgboost)
225+
226+
vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges")
227+
228+
# event_level = "first"
229+
fit_1 <-
230+
boost_tree(trees = 20, mode = "classification") %>%
231+
set_engine("xgboost") %>%
232+
fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)])
233+
234+
x <- xgboost::xgb.DMatrix(as.matrix(wa_churn[1:4, vars]))
235+
236+
pred_xgb_1 <- predict(fit_1$fit, x)
237+
pred_res_1 <- predict(fit_1, new_data = wa_churn[1:4, vars], type = "prob")
238+
expect_equal(pred_res_1[[".pred_Yes"]], pred_xgb_1)
239+
240+
# event_level = "second"
241+
fit_2 <-
242+
boost_tree(trees = 20, mode = "classification") %>%
243+
set_engine("xgboost", event_level = "second") %>%
244+
fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)])
245+
246+
x <- xgboost::xgb.DMatrix(as.matrix(wa_churn[1:4, vars]))
247+
248+
pred_xgb_2 <- predict(fit_2$fit, x)
249+
pred_res_2 <- predict(fit_2, new_data = wa_churn[1:4, vars], type = "prob")
250+
expect_equal(pred_res_2[[".pred_No"]], pred_xgb_2)
251+
})
221252

222253
test_that('default engine', {
223254
skip_if_not_installed("xgboost")

0 commit comments

Comments
 (0)