@@ -218,6 +218,37 @@ test_that('submodel prediction', {
218218 )
219219})
220220
221+ test_that(' prediction with event_level' , {
222+
223+ skip_if_not_installed(" xgboost" )
224+ library(xgboost )
225+
226+ vars <- c(" female" , " tenure" , " total_charges" , " phone_service" , " monthly_charges" )
227+
228+ # event_level = "first"
229+ fit_1 <-
230+ boost_tree(trees = 20 , mode = " classification" ) %> %
231+ set_engine(" xgboost" ) %> %
232+ fit(churn ~ . , data = wa_churn [- (1 : 4 ), c(" churn" , vars )])
233+
234+ x <- xgboost :: xgb.DMatrix(as.matrix(wa_churn [1 : 4 , vars ]))
235+
236+ pred_xgb_1 <- predict(fit_1 $ fit , x )
237+ pred_res_1 <- predict(fit_1 , new_data = wa_churn [1 : 4 , vars ], type = " prob" )
238+ expect_equal(pred_res_1 [[" .pred_Yes" ]], pred_xgb_1 )
239+
240+ # event_level = "second"
241+ fit_2 <-
242+ boost_tree(trees = 20 , mode = " classification" ) %> %
243+ set_engine(" xgboost" , event_level = " second" ) %> %
244+ fit(churn ~ . , data = wa_churn [- (1 : 4 ), c(" churn" , vars )])
245+
246+ x <- xgboost :: xgb.DMatrix(as.matrix(wa_churn [1 : 4 , vars ]))
247+
248+ pred_xgb_2 <- predict(fit_2 $ fit , x )
249+ pred_res_2 <- predict(fit_2 , new_data = wa_churn [1 : 4 , vars ], type = " prob" )
250+ expect_equal(pred_res_2 [[" .pred_No" ]], pred_xgb_2 )
251+ })
221252
222253test_that(' default engine' , {
223254 skip_if_not_installed(" xgboost" )
0 commit comments