Skip to content

Commit e3e8aaf

Browse files
committed
more fixes for tibble name repair
1 parent 52139aa commit e3e8aaf

File tree

8 files changed

+38
-8
lines changed

8 files changed

+38
-8
lines changed

R/aaa.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11

22
maybe_multivariate <- function(results, object) {
3+
34
if (isTRUE(ncol(results) > 1)) {
4-
results <- as_tibble(results)
5+
results <- as_tibble(results, .name_repair = "minimal")
56
} else {
67
results <- unname(results[, 1])
78
}

R/mlp.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,18 @@ keras_mlp <-
370370
fit_call <- rlang::call_modify(fit_call, !!!arg_values$fit)
371371

372372
history <- eval_tidy(fit_call)
373+
model$y_names <- colnames(y)
373374
model
374375
}
375376

377+
keras_numeric_post <- function(results, object) {
378+
if (ncol(results) > 1) {
379+
colnames(results) <- object$fit$y_names
380+
}
381+
maybe_multivariate(results, object)
382+
}
383+
384+
376385

377386
nnet_softmax <- function(results, object) {
378387
if (ncol(results) == 1)

R/mlp_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ set_pred(
8484
type = "numeric",
8585
value = list(
8686
pre = NULL,
87-
post = maybe_multivariate,
87+
post = keras_numeric_post,
8888
func = c(fun = "predict"),
8989
args =
9090
list(

R/surv_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ survreg_quant <- function(results, object) {
174174
pctl <- object$spec$method$pred$quantile$args$p
175175
n <- nrow(results)
176176
p <- ncol(results)
177+
colnames(results) <- names0(p)
177178
results <-
178179
results %>%
179180
as_tibble() %>%
180-
setNames(names0(p)) %>%
181181
mutate(.row = 1:n) %>%
182182
gather(.label, .pred, -.row) %>%
183183
arrange(.row, .label) %>%

tests/testthat/test_aaaa.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
library(testthat)
66

7-
context("setting keras environment")
7+
context("setting keras environment\n")
88

99
Sys.setenv(TF_CPP_MIN_LOG_LEVEL = '3')
1010
try(keras:::backend(), silent = TRUE)

tests/testthat/test_linear_reg.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,15 @@ test_that('lm execution', {
253253
regexp = NA
254254
)
255255

256+
expect_error(
257+
res <- fit_xy(
258+
iris_basic,
259+
x = iris[, 1:2],
260+
y = iris[3:4],
261+
control = ctrl
262+
),
263+
regexp = NA
264+
)
256265
})
257266

258267
test_that('lm prediction', {

tests/testthat/test_mars.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ test_that('mars execution', {
159159
),
160160
regexp = NA
161161
)
162+
163+
expect_error(
164+
res <- fit_xy(
165+
iris_basic,
166+
x = iris[, 1:2],
167+
y = iris[3:4],
168+
control = ctrl
169+
),
170+
regexp = NA
171+
)
162172
parsnip:::load_libs(res, attach = TRUE)
163173

164174
})

tests/testthat/test_mlp_keras.R

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ ctrl <- fit_control(verbosity = 1, catch = FALSE)
1616
caught_ctrl <- fit_control(verbosity = 1, catch = TRUE)
1717
quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE)
1818

19+
nn_dat <- read.csv("nnet_test.txt")
20+
1921
# ------------------------------------------------------------------------------
2022

2123
test_that('keras execution, classification', {
@@ -102,8 +104,8 @@ test_that('keras classification probabilities', {
102104
)
103105

104106
xy_pred <- keras::predict_proba(xy_fit$fit, x = as.matrix(iris[1:8, num_pred]))
105-
xy_pred <- as_tibble(xy_pred)
106107
colnames(xy_pred) <- paste0(".pred_", levels(iris$Species))
108+
xy_pred <- as_tibble(xy_pred)
107109
expect_equal(xy_pred, predict(xy_fit, new_data = iris[1:8, num_pred], type = "prob"))
108110

109111
keras::backend()$clear_session()
@@ -116,8 +118,8 @@ test_that('keras classification probabilities', {
116118
)
117119

118120
form_pred <- keras::predict_proba(form_fit$fit, x = as.matrix(iris[1:8, num_pred]))
119-
form_pred <- as_tibble(form_pred)
120121
colnames(form_pred) <- paste0(".pred_", levels(iris$Species))
122+
form_pred <- as_tibble(form_pred)
121123
expect_equal(form_pred, predict(form_fit, new_data = iris[1:8, num_pred], type = "prob"))
122124

123125
keras::backend()$clear_session()
@@ -204,8 +206,6 @@ test_that('keras regression prediction', {
204206

205207
# ------------------------------------------------------------------------------
206208

207-
nn_dat <- read.csv("nnet_test.txt")
208-
209209
test_that('multivariate nnet formula', {
210210
skip_on_cran()
211211
skip_if_not_installed("keras")
@@ -218,6 +218,7 @@ test_that('multivariate nnet formula', {
218218
data = nn_dat[-(1:5),]
219219
)
220220
expect_equal(length(unlist(keras::get_weights(nnet_form$fit))), 24)
221+
221222
nnet_form_pred <- predict(nnet_form, new_data = nn_dat[1:5, -(1:3)])
222223
expect_equal(names(nnet_form_pred), paste0(".pred_", c("V1", "V2", "V3")))
223224

0 commit comments

Comments
 (0)