Skip to content

Commit a5e9037

Browse files
authored
Merge pull request #313 from tidymodels/inner-names-predict
Remove inner names from prediction output
2 parents 2403a69 + 01aecfe commit a5e9037

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

R/predict.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ format_num <- function(x) {
191191
names(x) <- paste0(".pred_", names(x))
192192
}
193193
} else {
194-
x <- tibble(.pred = x)
194+
x <- tibble(.pred = unname(x))
195195
}
196196

197197
x
@@ -201,14 +201,15 @@ format_class <- function(x) {
201201
if (inherits(x, "tbl_spark"))
202202
return(x)
203203

204-
tibble(.pred_class = x)
204+
tibble(.pred_class = unname(x))
205205
}
206206

207207
format_classprobs <- function(x) {
208208
if (!any(grepl("^\\.pred_", names(x)))) {
209209
names(x) <- paste0(".pred_", names(x))
210210
}
211211
x <- as_tibble(x)
212+
x <- purrr::map_dfr(x, rlang::set_names, NULL)
212213
x
213214
}
214215

tests/testthat/test_logistic_reg.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,9 @@ test_that('glm probabilities', {
300300
control = ctrl
301301
)
302302

303-
xy_pred <- predict(classes_xy$fit, newdata = lending_club[1:7, num_pred], type = "response")
303+
xy_pred <- unname(predict(classes_xy$fit,
304+
newdata = lending_club[1:7, num_pred],
305+
type = "response"))
304306
xy_pred <- tibble(.pred_bad = 1 - xy_pred, .pred_good = xy_pred)
305307
expect_equal(xy_pred, predict(classes_xy, lending_club[1:7, num_pred], type = "prob"))
306308

tests/testthat/test_logistic_reg_glmnet.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,9 @@ test_that('glmnet probabilities, one lambda', {
251251
form_mat <- form_mat[1:7, -1]
252252

253253
form_pred <-
254-
predict(res_form$fit,
254+
unname(predict(res_form$fit,
255255
newx = form_mat,
256-
s = 0.1, type = "response")[, 1]
256+
s = 0.1, type = "response")[, 1])
257257
form_pred <- tibble(.pred_bad = 1 - form_pred, .pred_good = form_pred)
258258

259259
expect_equal(
@@ -358,7 +358,8 @@ test_that('glmnet probabilities, no lambda', {
358358

359359
expect_equal(
360360
mult_pred,
361-
multi_predict(xy_fit, lending_club[1:7, num_pred], type = "prob") %>% unnest()
361+
multi_predict(xy_fit, lending_club[1:7, num_pred], type = "prob") %>%
362+
unnest(cols = c(.pred))
362363
)
363364

364365
res_form <- fit(

0 commit comments

Comments
 (0)