Skip to content

Commit 0d0dd4c

Browse files
authored
Fix augment bug 🐛 related to bind_cols() (#510)
* Fix augment bind_cols bug * Update NEWS
1 parent af3e0ae commit 0d0dd4c

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

NEWS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
* The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between formula and matrix interface are now exported for developer use (#508).
44

5+
* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).
6+
57
# parsnip 0.1.6
68

79
## Model Specification Changes
@@ -19,7 +21,6 @@
1921

2022
* For xgboost, `mtry` and `colsample_bytree` can be passed as integer counts or proportions, while `subsample` and `validation` should always be proportions. `xgb_train()` now has a new option `counts` (`TRUE` or `FALSE`) that states which scale for `mtry` and `colsample_bytree` is being used. (#461)
2123

22-
r
2324
## Other Changes
2425

2526
* Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462).

R/augment.R

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,34 +56,35 @@
5656
#' augment(cls_xy, cls_tst[, -3])
5757
#'
5858
augment.model_fit <- function(x, new_data, ...) {
59+
ret <- new_data
5960
if (x$spec$mode == "regression") {
6061
check_spec_pred_type(x, "numeric")
61-
new_data <-
62-
new_data %>%
62+
ret <-
63+
ret %>%
6364
dplyr::bind_cols(
6465
predict(x, new_data = new_data)
6566
)
6667
if (length(x$preproc$y_var) > 0) {
6768
y_nm <- x$preproc$y_var
6869
if (any(names(new_data) == y_nm)) {
69-
new_data <- dplyr::mutate(new_data, .resid = !!rlang::sym(y_nm) - .pred)
70+
ret <- dplyr::mutate(ret, .resid = !!rlang::sym(y_nm) - .pred)
7071
}
7172
}
7273
} else if (x$spec$mode == "classification") {
7374
if (spec_has_pred_type(x, "class")) {
74-
new_data <- dplyr::bind_cols(
75-
new_data,
75+
ret <- dplyr::bind_cols(
76+
ret,
7677
predict(x, new_data = new_data, type = "class")
7778
)
7879
}
7980
if (spec_has_pred_type(x, "prob")) {
80-
new_data <- dplyr::bind_cols(
81-
new_data,
81+
ret <- dplyr::bind_cols(
82+
ret,
8283
predict(x, new_data = new_data, type = "prob")
8384
)
8485
}
8586
} else {
8687
rlang::abort(paste("Unknown mode:", x$spec$mode))
8788
}
88-
as_tibble(new_data)
89+
as_tibble(ret)
8990
}

0 commit comments

Comments
 (0)