Skip to content

Commit e2215bc

Browse files
committed
fixed mars multipredict bug
1 parent a8f3c83 commit e2215bc

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

R/mars.R

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,20 @@ multi_predict._earth <-
211211

212212
num_terms <- sort(num_terms)
213213

214+
# update.earth uses the values in the call so evaluate them if
215+
# they are quosures
216+
call_names <- names(object$fit$call)
217+
call_names <- call_names[!(call_names %in% c("", "x", "y"))]
218+
for (i in call_names) {
219+
if (is_quosure(object$fit$call[[i]]))
220+
object$fit$call[[i]] <- eval_tidy(object$fit$call[[i]])
221+
}
222+
214223
msg <-
215224
paste("Please use `keepxy = TRUE` as an option to enable submodel",
216225
"predictions with `earth`.")
217-
if (any(names(object$spec$eng_args) == "keepxy")) {
218-
if(!object$spec$eng_args$keepxy)
226+
if (any(names(object$fit$call) == "keepxy")) {
227+
if(!isTRUE(object$fit$call$keepxy))
219228
stop (msg, call. = FALSE)
220229
} else
221230
stop (msg, call. = FALSE)

tests/testthat/test_mars.R

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
library(testthat)
2-
32
library(parsnip)
43
library(rlang)
54

@@ -202,7 +201,6 @@ test_that('mars prediction', {
202201

203202

204203
test_that('submodel prediction', {
205-
skip("need fit$call object to have real values (not quosures)")
206204
skip_if_not_installed("earth")
207205
library(earth)
208206

@@ -215,12 +213,17 @@ test_that('submodel prediction', {
215213
set_engine("earth", keepxy = TRUE) %>%
216214
fit(mpg ~ ., data = mtcars[-(1:4), ])
217215

218-
pruned_fit <- update(reg_fit$fit, nprune = 5)
219-
pruned_pred <- predict(pruned_fit, mtcars[1:4, -1])[,1]
216+
tmp_reg <- reg_fit$fit
217+
tmp_reg$call[["pmethod"]] <- eval_tidy(tmp_reg$call[["pmethod"]])
218+
tmp_reg$call[["keepxy"]] <- eval_tidy(tmp_reg$call[["keepxy"]])
219+
tmp_reg$call[["nprune"]] <- eval_tidy(tmp_reg$call[["nprune"]])
220+
221+
pruned_reg <- update(tmp_reg, nprune = 5)
222+
pruned_reg_pred <- predict(pruned_reg, mtcars[1:4, -1])[,1]
220223

221224
mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], num_terms = 5)
222225
mp_res <- do.call("rbind", mp_res$.pred)
223-
expect_equal(mp_res[[".pred"]], pruned_pred)
226+
expect_equal(mp_res[[".pred"]], pruned_reg_pred)
224227

225228
vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges")
226229
class_fit <-
@@ -229,12 +232,16 @@ test_that('submodel prediction', {
229232
fit(churn ~ .,
230233
data = wa_churn[-(1:4), c("churn", vars)])
231234

232-
pruned_fit <- update(class_fit$fit, nprune = 5)
233-
pruned_pred <- predict(pruned_fit, wa_churn[1:4, vars], type = "response")[,1]
235+
cls_fit <- class_fit$fit
236+
cls_fit$call[["pmethod"]] <- eval_tidy(cls_fit$call[["pmethod"]])
237+
cls_fit$call[["keepxy"]] <- eval_tidy(cls_fit$call[["keepxy"]])
238+
239+
pruned_cls <- update(cls_fit, nprune = 5)
240+
pruned_cls_pred <- predict(pruned_cls, wa_churn[1:4, vars], type = "response")[,1]
234241

235242
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], num_terms = 5, type = "prob")
236243
mp_res <- do.call("rbind", mp_res$.pred)
237-
expect_equal(mp_res[[".pred_No"]], pruned_pred)
244+
expect_equal(mp_res[[".pred_No"]], pruned_cls_pred)
238245
})
239246

240247

0 commit comments

Comments
 (0)