@@ -43,6 +43,50 @@ test_that('classification predictions', {
4343 c(" .pred_high" , " .pred_low" ))
4444})
4545
46+
47+ test_that(' ordinal classification predictions' , {
48+ skip_if_not_installed(" modeldata" )
49+ skip_if_not_installed(" rpart" )
50+
51+ set.seed(382 )
52+ dat_tr <-
53+ modeldata :: sim_multinomial(
54+ 200 ,
55+ ~ - 0.5 + 0.6 * abs(A ),
56+ ~ ifelse(A > 0 & B > 0 , 1.0 + 0.2 * A / B , - 2 ),
57+ ~ - 0.6 * A + 0.50 * B - A * B ) %> %
58+ dplyr :: mutate(class = as.ordered(class ))
59+ dat_te <-
60+ modeldata :: sim_multinomial(
61+ 5 ,
62+ ~ - 0.5 + 0.6 * abs(A ),
63+ ~ ifelse(A > 0 & B > 0 , 1.0 + 0.2 * A / B , - 2 ),
64+ ~ - 0.6 * A + 0.50 * B - A * B ) %> %
65+ dplyr :: mutate(class = as.ordered(class ))
66+
67+ # ##
68+
69+ mod_f_fit <-
70+ decision_tree() %> %
71+ set_mode(" classification" ) %> %
72+ fit(class ~ . , data = dat_tr )
73+ expect_true(" ordered" %in% names(mod_f_fit ))
74+ mod_f_pred <- predict(mod_f_fit , dat_te )
75+ expect_true(is.ordered(mod_f_pred $ .pred_class ))
76+
77+ # ##
78+
79+ mod_xy_fit <-
80+ decision_tree() %> %
81+ set_mode(" classification" ) %> %
82+ fit_xy(x = dat_tr %> % dplyr :: select(- class ), dat_tr $ class )
83+
84+ expect_true(" ordered" %in% names(mod_xy_fit ))
85+ mod_xy_pred <- predict(mod_xy_fit , dat_te )
86+ expect_true(is.ordered(mod_f_pred $ .pred_class ))
87+ })
88+
89+
4690test_that(' non-standard levels' , {
4791 expect_true(is_tibble(predict(lr_fit , new_data = class_dat [1 : 5 ,- 1 ])))
4892 expect_true(is.factor(parsnip ::: predict_class.model_fit(lr_fit , new_data = class_dat [1 : 5 ,- 1 ])))
0 commit comments