11library(testthat )
2-
32library(parsnip )
43library(rlang )
54
@@ -202,7 +201,6 @@ test_that('mars prediction', {
202201
203202
204203test_that(' submodel prediction' , {
205- skip(" need fit$call object to have real values (not quosures)" )
206204 skip_if_not_installed(" earth" )
207205 library(earth )
208206
@@ -215,12 +213,17 @@ test_that('submodel prediction', {
215213 set_engine(" earth" , keepxy = TRUE ) %> %
216214 fit(mpg ~ . , data = mtcars [- (1 : 4 ), ])
217215
218- pruned_fit <- update(reg_fit $ fit , nprune = 5 )
219- pruned_pred <- predict(pruned_fit , mtcars [1 : 4 , - 1 ])[,1 ]
216+ tmp_reg <- reg_fit $ fit
217+ tmp_reg $ call [[" pmethod" ]] <- eval_tidy(tmp_reg $ call [[" pmethod" ]])
218+ tmp_reg $ call [[" keepxy" ]] <- eval_tidy(tmp_reg $ call [[" keepxy" ]])
219+ tmp_reg $ call [[" nprune" ]] <- eval_tidy(tmp_reg $ call [[" nprune" ]])
220+
221+ pruned_reg <- update(tmp_reg , nprune = 5 )
222+ pruned_reg_pred <- predict(pruned_reg , mtcars [1 : 4 , - 1 ])[,1 ]
220223
221224 mp_res <- multi_predict(reg_fit , new_data = mtcars [1 : 4 , - 1 ], num_terms = 5 )
222225 mp_res <- do.call(" rbind" , mp_res $ .pred )
223- expect_equal(mp_res [[" .pred" ]], pruned_pred )
226+ expect_equal(mp_res [[" .pred" ]], pruned_reg_pred )
224227
225228 vars <- c(" female" , " tenure" , " total_charges" , " phone_service" , " monthly_charges" )
226229 class_fit <-
@@ -229,12 +232,16 @@ test_that('submodel prediction', {
229232 fit(churn ~ . ,
230233 data = wa_churn [- (1 : 4 ), c(" churn" , vars )])
231234
232- pruned_fit <- update(class_fit $ fit , nprune = 5 )
233- pruned_pred <- predict(pruned_fit , wa_churn [1 : 4 , vars ], type = " response" )[,1 ]
235+ cls_fit <- class_fit $ fit
236+ cls_fit $ call [[" pmethod" ]] <- eval_tidy(cls_fit $ call [[" pmethod" ]])
237+ cls_fit $ call [[" keepxy" ]] <- eval_tidy(cls_fit $ call [[" keepxy" ]])
238+
239+ pruned_cls <- update(cls_fit , nprune = 5 )
240+ pruned_cls_pred <- predict(pruned_cls , wa_churn [1 : 4 , vars ], type = " response" )[,1 ]
234241
235242 mp_res <- multi_predict(class_fit , new_data = wa_churn [1 : 4 , vars ], num_terms = 5 , type = " prob" )
236243 mp_res <- do.call(" rbind" , mp_res $ .pred )
237- expect_equal(mp_res [[" .pred_No" ]], pruned_pred )
244+ expect_equal(mp_res [[" .pred_No" ]], pruned_cls_pred )
238245})
239246
240247
0 commit comments