66# ' [fit()] and `new_data` contains the outcome column, a `.resid` column is
77# ' also added.
88# '
9- # ' For classification models, the results include a column called `.pred_class`
10- # ' as well as class probability columns named `.pred_{level}`.
9+ # ' For classification models, the results can include a column called
10+ # ' `.pred_class` as well as class probability columns named `.pred_{level}`.
11+ # ' This depends on what type of prediction types are available for the model.
1112# ' @param x A `model_fit` object produced by [fit()] or [fit_xy()].
1213# ' @param new_data A data frame or matrix.
1314# ' @param ... Not currently used.
5657# '
5758augment.model_fit <- function (x , new_data , ... ) {
5859 if (x $ spec $ mode == " regression" ) {
60+ check_spec_pred_type(x , " numeric" )
5961 new_data <-
6062 new_data %> %
6163 dplyr :: bind_cols(
@@ -68,13 +70,13 @@ augment.model_fit <- function(x, new_data, ...) {
6870 }
6971 }
7072 } else if (x $ spec $ mode == " classification" ) {
71- if (has_class_preds( x )) {
73+ if (spec_has_pred_type( x , " class " )) {
7274 new_data <- dplyr :: bind_cols(
7375 new_data ,
7476 predict(x , new_data = new_data , type = " class" )
7577 )
7678 }
77- if (has_class_probs( x )) {
79+ if (spec_has_pred_type( x , " prob " )) {
7880 new_data <- dplyr :: bind_cols(
7981 new_data ,
8082 predict(x , new_data = new_data , type = " prob" )
@@ -85,11 +87,3 @@ augment.model_fit <- function(x, new_data, ...) {
8587 }
8688 as_tibble(new_data )
8789}
88-
89- has_class_preds <- function (x ) {
90- any(names(x $ spec $ method $ pred ) == " class" )
91- }
92-
93- has_class_probs <- function (x ) {
94- any(names(x $ spec $ method $ pred ) == " prob" )
95- }
0 commit comments