1313# ' @param ... Not currently used.
1414# ' @export
1515# ' @examples
16+ # ' car_trn <- mtcars[11:32,]
17+ # ' car_tst <- mtcars[ 1:10,]
18+ # '
1619# ' reg_form <-
1720# ' linear_reg() %>%
1821# ' set_engine("lm") %>%
19- # ' fit(mpg ~ ., data = mtcars )
22+ # ' fit(mpg ~ ., data = car_trn )
2023# ' reg_xy <-
2124# ' linear_reg() %>%
2225# ' set_engine("lm") %>%
23- # ' fit_xy(mtcars [, -1], mtcars $mpg)
26+ # ' fit_xy(car_trn [, -1], car_trn $mpg)
2427# '
25- # ' augment(reg_form, head(mtcars) )
26- # ' augment(reg_form, head(mtcars [, -1]) )
28+ # ' augment(reg_form, car_tst )
29+ # ' augment(reg_form, car_tst [, -1])
2730# '
28- # ' augment(reg_xy, head(mtcars) )
29- # ' augment(reg_xy, head(mtcars [, -1]) )
31+ # ' augment(reg_xy, car_tst )
32+ # ' augment(reg_xy, car_tst [, -1])
3033# '
3134# ' # ------------------------------------------------------------------------------
3235# '
3336# ' data(two_class_dat, package = "modeldata")
37+ # ' cls_trn <- two_class_dat[-(1:10), ]
38+ # ' cls_tst <- two_class_dat[ 1:10 , ]
3439# '
3540# ' cls_form <-
3641# ' logistic_reg() %>%
3742# ' set_engine("glm") %>%
38- # ' fit(Class ~ ., data = two_class_dat )
43+ # ' fit(Class ~ ., data = cls_trn )
3944# ' cls_xy <-
4045# ' logistic_reg() %>%
4146# ' set_engine("glm") %>%
42- # ' fit_xy(two_class_dat [, -3],
43- # ' two_class_dat $Class)
47+ # ' fit_xy(cls_trn [, -3],
48+ # ' cls_trn $Class)
4449# '
45- # ' augment(cls_form, head(two_class_dat) )
46- # ' augment(cls_form, head(two_class_dat [, -3]) )
50+ # ' augment(cls_form, cls_tst )
51+ # ' augment(cls_form, cls_tst [, -3])
4752# '
48- # ' augment(cls_xy, head(two_class_dat) )
49- # ' augment(cls_xy, head(two_class_dat [, -3]) )
53+ # ' augment(cls_xy, cls_tst )
54+ # ' augment(cls_xy, cls_tst [, -3])
5055# '
5156augment.model_fit <- function (x , new_data , ... ) {
5257 if (x $ spec $ mode == " regression" ) {
@@ -61,13 +66,15 @@ augment.model_fit <- function(x, new_data, ...) {
6166 new_data <- dplyr :: mutate(new_data , .resid = !! rlang :: sym(y_nm ) - .pred )
6267 }
6368 }
64- } else {
69+ } else if ( x $ spec $ mode == " classification " ) {
6570 new_data <-
6671 new_data %> %
6772 dplyr :: bind_cols(
6873 predict(x , new_data = new_data , type = " class" ),
6974 predict(x , new_data = new_data , type = " prob" )
7075 )
76+ } else {
77+ rlang :: abort(paste(" Unknown mode:" , x $ spec $ mode ))
7178 }
7279 new_data
7380}
0 commit comments