Skip to content

Commit 319050f

Browse files
committed
framework for multivariate y predictions
1 parent 7b36906 commit 319050f

File tree

7 files changed

+61
-14
lines changed

7 files changed

+61
-14
lines changed

R/aaa.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
maybe_multivariate <- function(results, object) {
33

44
if (isTRUE(ncol(results) > 1)) {
5+
nms <- colnames(results)
56
results <- as_tibble(results, .name_repair = "minimal")
7+
if (length(nms) == 0 && length(object$preproc$y_var) == ncol(results)) {
8+
names(results) <- object$preproc$y_var
9+
}
610
} else {
711
results <- unname(results[, 1])
812
}

R/fit_helpers.R

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ form_form <-
5959
env = env,
6060
...
6161
)
62-
res$preproc <- NA
62+
res$preproc <- list(y_var = all.vars(env$formula[[2]]))
6363
res
6464
}
6565

@@ -114,7 +114,12 @@ xy_xy <- function(object, env, control, target = "none", ...) {
114114
env = env,
115115
...
116116
)
117-
res$preproc <- NA
117+
if (is.vector(env$y)) {
118+
y_name <- character(0)
119+
} else {
120+
y_name <- colnames(env$y)
121+
}
122+
res$preproc <- list(y_var = y_name)
118123
res
119124
}
120125

@@ -144,6 +149,7 @@ form_xy <- function(object, control, env,
144149
control = control,
145150
target = target
146151
)
152+
data_obj$y_var <- all.vars(env$formula[[2]])
147153
data_obj$x <- NULL
148154
data_obj$y <- NULL
149155
data_obj$weights <- NULL
@@ -177,7 +183,12 @@ xy_form <- function(object, env, control, ...) {
177183
control = control,
178184
...
179185
)
180-
res$preproc <- data_obj["x_var"]
186+
if (is.vector(env$y)) {
187+
data_obj$y_var <- character(0)
188+
} else {
189+
data_obj$y_var <- colnames(env$y)
190+
}
191+
res$preproc <- data_obj[c("x_var", "y_var")]
181192
res
182193
}
183194

R/misc.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,15 @@ check_outcome <- function(y, spec) {
223223
invisible(NULL)
224224
}
225225

226+
227+
# Get's a character string of varible names used as the outcome
228+
# in a terms object
229+
terms_y <- function(x) {
230+
att <- attributes(x)
231+
resp_ind <- att$response
232+
y_expr <- att$predvars[[resp_ind + 1]]
233+
all.vars(y_expr)
234+
}
235+
236+
237+

R/mlp.R

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,6 @@ keras_mlp <-
374374
model
375375
}
376376

377-
keras_numeric_post <- function(results, object) {
378-
if (ncol(results) > 1 && !is.null(object$fit$y_names)) {
379-
print(object$fit$y_names)
380-
colnames(results) <- object$fit$y_names
381-
}
382-
maybe_multivariate(results, object)
383-
}
384-
385-
386377

387378
nnet_softmax <- function(results, object) {
388379
if (ncol(results) == 1)

R/mlp_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ set_pred(
8484
type = "numeric",
8585
value = list(
8686
pre = NULL,
87-
post = keras_numeric_post,
87+
post = maybe_multivariate,
8888
func = c(fun = "predict"),
8989
args =
9090
list(

R/predict.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,12 @@ make_pred_call <- function(x) {
216216
cl
217217
}
218218

219+
219220
prepare_data <- function(object, new_data) {
220221
fit_interface <- object$spec$method$fit$interface
221222

222-
if (!all(is.na(object$preproc))) {
223+
pp_names <- names(object$preproc)
224+
if (any(pp_names == "terms") | any(pp_names == "x_var")) {
223225
# Translation code
224226
if (fit_interface == "formula") {
225227
new_data <- convert_xy_to_form_new(object$preproc, new_data)

tests/testthat/test_misc.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,30 @@ test_that('other objects', {
2929

3030
# ------------------------------------------------------------------------------
3131

32+
context("getting y names from terms")
33+
34+
test_that('getting y names from terms', {
35+
36+
expect_equal(
37+
parsnip:::terms_y(lm(cbind(mpg, disp) ~., data = mtcars)$terms),
38+
c("mpg", "disp")
39+
)
40+
41+
expect_equal(
42+
parsnip:::terms_y(lm(mpg ~., data = mtcars)$terms),
43+
"mpg"
44+
)
45+
46+
expect_equal(
47+
parsnip:::terms_y(lm(log(mpg) ~., data = mtcars)$terms),
48+
"mpg"
49+
)
50+
51+
expect_equal(
52+
parsnip:::terms_y(terms( ~., data = mtcars)),
53+
character(0)
54+
)
55+
56+
57+
})
58+

0 commit comments

Comments
 (0)