@@ -28,6 +28,12 @@ lr_fit_2 <-
2828 set_engine(" glm" ) %> %
2929 fit(Ozone ~ . , data = class_dat2 )
3030
31+ lr_fit_3 <-
32+ mlp(mode = ' classification' ) %> %
33+ set_engine(" nnet" ) %> %
34+ fit(Ozone ~ . , data = class_dat2 [1 : 5 , ])
35+
36+
3137# ------------------------------------------------------------------------------
3238
3339test_that(' regression predictions' , {
@@ -54,8 +60,11 @@ test_that('non-standard levels', {
5460
5561 expect_true(is_tibble(predict(lr_fit_2 , new_data = class_dat2 [1 : 5 ,- 1 ], type = " prob" )))
5662 expect_true(is_tibble(parsnip ::: predict_classprob.model_fit(lr_fit_2 , new_data = class_dat2 [1 : 5 ,- 1 ])))
63+ final_colnames <- c(" .pred_2low" , " .pred_high+values" )
5764 expect_equal(names(predict(lr_fit_2 , new_data = class_dat2 [1 : 5 ,- 1 ], type = " prob" )),
58- c(" .pred_2low" , " .pred_high+values" ))
65+ final_colnames )
66+ expect_equal(names(predict(lr_fit_3 , new_data = class_dat2 , type = ' prob' )),
67+ final_colnames )
5968 expect_equal(names(parsnip ::: predict_classprob.model_fit(lr_fit_2 , new_data = class_dat2 [1 : 5 ,- 1 ])),
6069 c(" 2low" , " high+values" ))
6170})
0 commit comments