Skip to content

Commit 4385332

Browse files
committed
closes #89
1 parent fcace3f commit 4385332

13 files changed

+84
-0
lines changed

R/boost_tree.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ xgb_pred <- function(object, newdata, ...) {
359359
#' @export
360360
multi_predict._xgb.Booster <-
361361
function(object, new_data, type = NULL, trees = NULL, ...) {
362+
if (any(names(enquos(...)) == "newdata"))
363+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
364+
362365
if (is.null(trees))
363366
trees <- object$fit$nIter
364367
trees <- sort(trees)
@@ -458,6 +461,9 @@ C5.0_train <-
458461
#' @export
459462
multi_predict._C5.0 <-
460463
function(object, new_data, type = NULL, trees = NULL, ...) {
464+
if (any(names(enquos(...)) == "newdata"))
465+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
466+
461467
if (is.null(trees))
462468
trees <- min(object$fit$trials)
463469
trees <- sort(trees)

R/linear_reg.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,18 +211,27 @@ organize_glmnet_pred <- function(x, object) {
211211
#' @export
212212
predict._elnet <-
213213
function(object, new_data, type = NULL, opts = list(), ...) {
214+
if (any(names(enquos(...)) == "newdata"))
215+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
216+
214217
object$spec <- eval_args(object$spec)
215218
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
216219
}
217220

218221
#' @export
219222
predict_num._elnet <- function(object, new_data, ...) {
223+
if (any(names(enquos(...)) == "newdata"))
224+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
225+
220226
object$spec <- eval_args(object$spec)
221227
predict_num.model_fit(object, new_data = new_data, ...)
222228
}
223229

224230
#' @export
225231
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
232+
if (any(names(enquos(...)) == "newdata"))
233+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
234+
226235
object$spec <- eval_args(object$spec)
227236
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
228237
}
@@ -232,6 +241,9 @@ predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
232241
#' @export
233242
multi_predict._elnet <-
234243
function(object, new_data, type = NULL, penalty = NULL, ...) {
244+
if (any(names(enquos(...)) == "newdata"))
245+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
246+
235247
dots <- list(...)
236248
if (is.null(penalty))
237249
penalty <- object$fit$lambda

R/logistic_reg.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,24 +230,36 @@ organize_glmnet_prob <- function(x, object) {
230230

231231
#' @export
232232
predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
233+
if (any(names(enquos(...)) == "newdata"))
234+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
235+
233236
object$spec <- eval_args(object$spec)
234237
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
235238
}
236239

237240
#' @export
238241
predict_class._lognet <- function (object, new_data, ...) {
242+
if (any(names(enquos(...)) == "newdata"))
243+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
244+
239245
object$spec <- eval_args(object$spec)
240246
predict_class.model_fit(object, new_data = new_data, ...)
241247
}
242248

243249
#' @export
244250
predict_classprob._lognet <- function (object, new_data, ...) {
251+
if (any(names(enquos(...)) == "newdata"))
252+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
253+
245254
object$spec <- eval_args(object$spec)
246255
predict_classprob.model_fit(object, new_data = new_data, ...)
247256
}
248257

249258
#' @export
250259
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
260+
if (any(names(enquos(...)) == "newdata"))
261+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
262+
251263
object$spec <- eval_args(object$spec)
252264
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
253265
}
@@ -258,6 +270,9 @@ predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
258270
#' @export
259271
multi_predict._lognet <-
260272
function(object, new_data, type = NULL, penalty = NULL, ...) {
273+
if (any(names(enquos(...)) == "newdata"))
274+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
275+
261276
dots <- list(...)
262277
if (is.null(penalty))
263278
penalty <- object$lambda

R/mars.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ earth_reg_updater <- function(num, object, new_data, ...) {
206206
#' @export
207207
multi_predict._earth <-
208208
function(object, new_data, type = NULL, num_terms = NULL, ...) {
209+
if (any(names(enquos(...)) == "newdata"))
210+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
211+
209212
if (is.null(num_terms))
210213
num_terms <- object$fit$selected.terms[-1]
211214

R/multinom_reg.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ predict._multnet <-
236236
#' @export
237237
multi_predict._multnet <-
238238
function(object, new_data, type = NULL, penalty = NULL, ...) {
239+
if (any(names(enquos(...)) == "newdata"))
240+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
241+
239242
if (is_quosure(penalty))
240243
penalty <- eval_tidy(penalty)
241244

R/predict.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@
9191
#' @export predict.model_fit
9292
#' @export
9393
predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...) {
94+
if (any(names(enquos(...)) == "newdata"))
95+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
96+
9497
type <- check_pred_type(object, type)
9598
if (type != "raw" && length(opts) > 0)
9699
warning("`opts` is only used with `type = 'raw'` and was ignored.")

tests/testthat/test_boost_tree_C50.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,5 +121,10 @@ test_that('submodel prediction', {
121121
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 4, type = "prob")
122122
mp_res <- do.call("rbind", mp_res$.pred)
123123
expect_equal(mp_res[[".pred_No"]], unname(pred_class[, "No"]))
124+
125+
expect_error(
126+
multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 4, type = "prob"),
127+
"Did you mean"
128+
)
124129
})
125130

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,5 +188,10 @@ test_that('submodel prediction', {
188188
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 5, type = "prob")
189189
mp_res <- do.call("rbind", mp_res$.pred)
190190
expect_equal(mp_res[[".pred_No"]], pred_class)
191+
192+
expect_error(
193+
multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 5, type = "prob"),
194+
"Did you mean"
195+
)
191196
})
192197

tests/testthat/test_linear_reg.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,14 @@ test_that('lm intervals', {
322322
expect_equivalent(prediction_parsnip$.pred_upper, prediction_lm[, "upr"])
323323
})
324324

325+
326+
test_that('newdata error trapping', {
327+
res_xy <- fit_xy(
328+
iris_basic,
329+
x = iris[, num_pred],
330+
y = iris$Sepal.Length,
331+
control = ctrl
332+
)
333+
expect_error(predict(res_xy, newdata = iris[1:3, num_pred]), "Did you mean")
334+
})
335+

tests/testthat/test_linear_reg_glmnet.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,5 +198,10 @@ test_that('submodel prediction', {
198198
mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], penalty = .1)
199199
mp_res <- do.call("rbind", mp_res$.pred)
200200
expect_equal(mp_res[[".pred"]], unname(pred_glmn[,1]))
201+
202+
expect_error(
203+
multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1),
204+
"Did you mean"
205+
)
201206
})
202207

0 commit comments

Comments
 (0)