Skip to content

Commit d8c7484

Browse files
committed
un-do return values for failed models for issue #123
1 parent 9f7a25a commit d8c7484

File tree

12 files changed

+85
-133
lines changed

12 files changed

+85
-133
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ S3method(predict,model_fit)
1717
S3method(predict,model_spec)
1818
S3method(predict,nullmodel)
1919
S3method(predict_class,"_lognet")
20+
S3method(predict_class,"_multnet")
2021
S3method(predict_classprob,"_lognet")
2122
S3method(predict_classprob,"_multnet")
2223
S3method(predict_numeric,"_elnet")

R/predict.R

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
112112
if (any(names(the_dots) == "newdata"))
113113
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
114114

115+
if (inherits(object$fit, "try-error")) {
116+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
117+
return(NULL)
118+
}
119+
115120
other_args <- c("level", "std_error", "quantile") # "time" for survival probs later
116121
is_pred_arg <- names(the_dots) %in% other_args
117122
if (any(!is_pred_arg)) {
@@ -242,8 +247,13 @@ prepare_data <- function(object, new_data) {
242247
#' multiple rows per sub-model.
243248
#' @keywords internal
244249
#' @export
245-
multi_predict <- function(object, ...)
250+
multi_predict <- function(object, ...) {
251+
if (inherits(object$fit, "try-error")) {
252+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
253+
return(NULL)
254+
}
246255
UseMethod("multi_predict")
256+
}
247257

248258
#' @keywords internal
249259
#' @export
@@ -256,11 +266,3 @@ multi_predict.default <- function(object, ...)
256266
predict.model_spec <- function(object, ...) {
257267
stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE)
258268
}
259-
260-
261-
failed_class <- function(n, lvl) {
262-
tibble(.pred = rep(NA_real_, n))
263-
}
264-
265-
266-

R/predict_class.R

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ predict_class.model_fit <- function(object, new_data, ...) {
1717
stop("No class prediction module defined for this model.", call. = FALSE)
1818

1919
if (inherits(object$fit, "try-error")) {
20-
return(failed_class(lvl = object$lvl))
20+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
21+
return(NULL)
2122
}
2223

2324
new_data <- prepare_data(object, new_data)
@@ -54,13 +55,3 @@ predict_class.model_fit <- function(object, new_data, ...) {
5455
predict_class <- function(object, ...)
5556
UseMethod("predict_class")
5657

57-
# ------------------------------------------------------------------------------
58-
59-
# Some `predict()` helpers for failed models:
60-
61-
failed_class <- function(n = 1, lvl) {
62-
res <- rep(NA_character_, n)
63-
res <- factor(res, levels = lvl)
64-
res
65-
}
66-

R/predict_classprob.R

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
1414
stop("No class probability module defined for this model.", call. = FALSE)
1515

1616
if (inherits(object$fit, "try-error")) {
17-
return(failed_classprob(lvl = object$lvl))
17+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
18+
return(NULL)
1819
}
1920

2021
new_data <- prepare_data(object, new_data)
@@ -49,15 +50,3 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
4950
# @inheritParams predict.model_fit
5051
predict_classprob <- function(object, ...)
5152
UseMethod("predict_classprob")
52-
53-
54-
# ------------------------------------------------------------------------------
55-
56-
# Some `predict()` helpers for failed models:
57-
58-
failed_classprob <- function(n = 1, lvl) {
59-
res <- matrix(NA_real_, nrow = n, ncol = length(lvl))
60-
colnames(res) <- lvl
61-
as_tibble(res)
62-
}
63-

R/predict_interval.R

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error
1515
"engine.", call. = FALSE)
1616

1717
if (inherits(object$fit, "try-error")) {
18-
return(failed_int(lvl = object$lvl))
18+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
19+
return(NULL)
1920
}
2021

2122
new_data <- prepare_data(object, new_data)
@@ -50,24 +51,6 @@ predict_confint <- function(object, ...)
5051

5152
# ------------------------------------------------------------------------------
5253

53-
# Some `predict()` helpers for failed models:
54-
55-
failed_int <- function(n = 1, lvl = NULL, nms = ".pred") {
56-
# TODO figure out multivariate models
57-
if (is.null(lvl)) {
58-
res <- matrix(NA_real_, nrow = n, ncol = length(nms) * 2)
59-
colnames(res) <- c(".pred_lower", ".pred_upper")
60-
} else {
61-
res <- matrix(NA_real_, ncol = length(lvl) * 2, nrow = n)
62-
nms <- expand.grid(c("lower", "upper"), lvl)
63-
nms <- paste(".pred", nms$Var1, nms$Var2, sep = "_")
64-
colnames(res) <- nms
65-
}
66-
as_tibble(res)
67-
}
68-
69-
# ------------------------------------------------------------------------------
70-
7154
# @keywords internal
7255
# @rdname other_predict
7356
# @inheritParams predict.model_fit
@@ -81,7 +64,8 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error
8164
"engine.", call. = FALSE)
8265

8366
if (inherits(object$fit, "try-error")) {
84-
return(failed_int(lvl = object$lvl))
67+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
68+
return(NULL)
8569
}
8670

8771
new_data <- prepare_data(object, new_data)

R/predict_numeric.R

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
1515
stop("No prediction module defined for this model.", call. = FALSE)
1616

1717
if (inherits(object$fit, "try-error")) {
18-
# TODO handle multivariate cases
19-
return(failed_numeric())
18+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
19+
return(NULL)
2020
}
2121

2222
new_data <- prepare_data(object, new_data)
@@ -51,21 +51,3 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
5151
# @inheritParams predict_numeric.model_fit
5252
predict_numeric <- function(object, ...)
5353
UseMethod("predict_numeric")
54-
55-
# ------------------------------------------------------------------------------
56-
57-
# Some `predict()` helpers for failed models:
58-
59-
failed_numeric <- function(n = 1, nms = ".pred") {
60-
res <- matrix(NA_real_, ncol = length(nms), nrow = n)
61-
if (length(nms) > 1) {
62-
colnames(res) <- nms
63-
res <- as_tibble(res)
64-
} else {
65-
res <- res[,1]
66-
}
67-
res
68-
}
69-
70-
71-

R/predict_quantile.R

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,40 @@
11
# @keywords internal
22
# @rdname other_predict
3-
# @param quant A vector of numbers between 0 and 1 for the quantile being
4-
# predicted.
3+
# @param quant A vector of numbers between 0 and 1 for the quantile being
4+
# predicted.
55
# @inheritParams predict.model_fit
66
# @method predict_quantile model_fit
77
# @export predict_quantile.model_fit
88
# @export
99
predict_quantile.model_fit <-
1010
function (object, new_data, quantile = (1:9)/10, ...) {
11-
11+
1212
if (is.null(object$spec$method$quantile))
1313
stop("No quantile prediction method defined for this ",
1414
"engine.", call. = FALSE)
15-
15+
16+
if (inherits(object$fit, "try-error")) {
17+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
18+
return(NULL)
19+
}
20+
1621
new_data <- prepare_data(object, new_data)
17-
22+
1823
# preprocess data
1924
if (!is.null(object$spec$method$quantile$pre))
2025
new_data <- object$spec$method$quantile$pre(new_data, object)
21-
26+
2227
# Pass some extra arguments to be used in post-processor
2328
object$spec$method$quantile$args$p <- quantile
2429
pred_call <- make_pred_call(object$spec$method$quantile)
25-
30+
2631
res <- eval_tidy(pred_call)
27-
32+
2833
# post-process the predictions
2934
if(!is.null(object$spec$method$quantile$post)) {
3035
res <- object$spec$method$quantile$post(res, object)
3136
}
32-
37+
3338
res
3439
}
3540

R/predict_raw.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) {
1818
stop("No raw prediction module defined for this model.", call. = FALSE)
1919

2020
if (inherits(object$fit, "try-error")) {
21-
stop("Model fit failed; cannot make predictions.")
21+
warning("Model fit failed; cannot make predictions.", call. = FALSE)
22+
return(NULL)
2223
}
2324

2425
new_data <- prepare_data(object, new_data)

man/linear_reg.Rd

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/logistic_reg.Rd

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)