Skip to content

Commit c6a687e

Browse files
authored
Merge pull request #157 from tidymodels/unexport-predict-funcs
Various prediction-related changes
2 parents 5f746aa + 07e77a9 commit c6a687e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+755
-479
lines changed

NAMESPACE

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,13 @@ S3method(predict,model_fit)
1717
S3method(predict,model_spec)
1818
S3method(predict,nullmodel)
1919
S3method(predict_class,"_lognet")
20-
S3method(predict_class,model_fit)
20+
S3method(predict_class,"_multnet")
2121
S3method(predict_classprob,"_lognet")
2222
S3method(predict_classprob,"_multnet")
23-
S3method(predict_classprob,model_fit)
24-
S3method(predict_confint,model_fit)
2523
S3method(predict_numeric,"_elnet")
26-
S3method(predict_numeric,model_fit)
27-
S3method(predict_predint,model_fit)
28-
S3method(predict_quantile,model_fit)
2924
S3method(predict_raw,"_elnet")
3025
S3method(predict_raw,"_lognet")
3126
S3method(predict_raw,"_multnet")
32-
S3method(predict_raw,model_fit)
3327
S3method(print,boost_tree)
3428
S3method(print,decision_tree)
3529
S3method(print,fit_control)
@@ -103,20 +97,6 @@ export(nearest_neighbor)
10397
export(null_model)
10498
export(nullmodel)
10599
export(predict.model_fit)
106-
export(predict_class)
107-
export(predict_class.model_fit)
108-
export(predict_classprob)
109-
export(predict_classprob.model_fit)
110-
export(predict_confint)
111-
export(predict_confint.model_fit)
112-
export(predict_numeric)
113-
export(predict_numeric.model_fit)
114-
export(predict_predint)
115-
export(predict_predint.model_fit)
116-
export(predict_quantile)
117-
export(predict_quantile.model_fit)
118-
export(predict_raw)
119-
export(predict_raw.model_fit)
120100
export(rand_forest)
121101
export(rpart_train)
122102
export(set_args)

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## New Features
44

55
* A "null model" is now available that fits a predictor-free model (using the mean of the outcome for regression or the mode for classification).
6+
67
* `fit_xy()` can take a single column data frame or matrix for `y` without error
78

89
## Other Changes
@@ -13,6 +14,8 @@ that are actually varying).
1314

1415
* `fit_control()` now returns an S3 method.
1516

17+
* The prediction modules (e.g. `predict_class`, `predict_numeric`, etc.) were de-exported. These were internal functions that were not intended to be called directly, but users were calling them.
18+
1619
## Bug Fixes
1720

1821
* `varying_args()` now uses the version from the `generics` package. This means
@@ -33,6 +36,7 @@ column names once (#107).
3336
* For multinomial regression using glmnet, `multi_predict()` now pulls the
3437
correct default penalty (#108).
3538

39+
* Confidence and prediction intervals for logistic regression were only computed for a single level. Intervals for both levels are now computed. (#156)
3640

3741

3842
# parsnip 0.0.1

R/linear_reg.R

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,20 @@
6363
#' \pkg{spark}
6464
#'
6565
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")}
66-
#'
66+
#'
6767
#' \pkg{keras}
6868
#'
6969
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
7070
#'
7171
#' When using `glmnet` models, there is the option to pass
72-
#' multiple values (or no values) to the `penalty` argument.
73-
#' This can have an effect on the model object results. When using
74-
#' the `predict()` method in these cases, the return object type
75-
#' depends on the value of `penalty`. If a single value is
76-
#' given, the results will be a simple numeric vector. When
77-
#' multiple values or no values for `penalty` are used in
78-
#' `linear_reg()`, the `predict()` method will return a data frame with
79-
#' columns `values` and `lambda`.
72+
#' multiple values (or no values) to the `penalty` argument. This
73+
#' can have an effect on the model object results. When using the
74+
#' `predict()` method in these cases, the return value depends on
75+
#' the value of `penalty`. When using `predict()`, only a single
76+
#' value of the penalty can be used. When predicting on multiple
77+
#' penalties, the `multi_predict()` function can be used. It
78+
#' returns a tibble with a list column called `.pred` that contains
79+
#' a tibble with all of the penalty results.
8080
#'
8181
#' For prediction, the `stan` engine can compute posterior
8282
#' intervals analogous to confidence and prediction intervals. In
@@ -130,7 +130,7 @@ print.linear_reg <- function(x, ...) {
130130
cat("Linear Regression Model Specification (", x$mode, ")\n\n", sep = "")
131131
model_printer(x, ...)
132132

133-
if(!is.null(x$method$fit$args)) {
133+
if (!is.null(x$method$fit$args)) {
134134
cat("Model fit template:\n")
135135
print(show_call(x))
136136
}
@@ -216,12 +216,66 @@ organize_glmnet_pred <- function(x, object) {
216216

217217
# ------------------------------------------------------------------------------
218218

219+
# For `predict` methods that use `glmnet`, we have specific methods.
220+
# Only one value of the penalty should be allowed when called by `predict()`:
221+
222+
check_penalty <- function(penalty = NULL, object, multi = FALSE) {
223+
224+
if (is.null(penalty)) {
225+
penalty <- object$fit$lambda
226+
}
227+
228+
# when using `predict()`, allow for a single lambda
229+
if (!multi) {
230+
if (length(penalty) != 1)
231+
stop("`penalty` should be a single numeric value. ",
232+
"`multi_predict()` can be used to get multiple predictions ",
233+
"per row of data.", call. = FALSE)
234+
}
235+
236+
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
237+
stop("The glmnet model was fit with a single penalty value of ",
238+
object$fit$lambda, ". Predicting with a value of ",
239+
penalty, " will give incorrect results from `glmnet()`.",
240+
call. = FALSE)
241+
242+
penalty
243+
}
244+
245+
# ------------------------------------------------------------------------------
246+
# glmnet call stack for linear regression using `predict` when object has
247+
# classes "_elnet" and "model_fit":
248+
#
249+
# predict()
250+
# predict._elnet(penalty = NULL) <-- checks and sets penalty
251+
# predict.model_fit() <-- checks for extra vars in ...
252+
# predict_numeric()
253+
# predict_numeric._elnet()
254+
# predict_numeric.model_fit()
255+
# predict.elnet()
256+
257+
258+
# glmnet call stack for linear regression using `multi_predict` when object has
259+
# classes "_elnet" and "model_fit":
260+
#
261+
# multi_predict()
262+
# multi_predict._elnet(penalty = NULL)
263+
# predict._elnet(multi = TRUE) <-- checks and sets penalty
264+
# predict.model_fit() <-- checks for extra vars in ...
265+
# predict_raw()
266+
# predict_raw._elnet()
267+
# predict_raw.model_fit(opts = list(s = penalty))
268+
# predict.elnet()
269+
270+
219271
#' @export
220272
predict._elnet <-
221-
function(object, new_data, type = NULL, opts = list(), ...) {
273+
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
222274
if (any(names(enquos(...)) == "newdata"))
223275
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
224-
276+
277+
object$spec$args$penalty <- check_penalty(penalty, object, multi)
278+
225279
object$spec <- eval_args(object$spec)
226280
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
227281
}
@@ -230,7 +284,7 @@ predict._elnet <-
230284
predict_numeric._elnet <- function(object, new_data, ...) {
231285
if (any(names(enquos(...)) == "newdata"))
232286
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
233-
287+
234288
object$spec <- eval_args(object$spec)
235289
predict_numeric.model_fit(object, new_data = new_data, ...)
236290
}
@@ -239,8 +293,9 @@ predict_numeric._elnet <- function(object, new_data, ...) {
239293
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
240294
if (any(names(enquos(...)) == "newdata"))
241295
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
242-
296+
243297
object$spec <- eval_args(object$spec)
298+
opts$s <- object$spec$args$penalty
244299
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
245300
}
246301

@@ -251,14 +306,17 @@ multi_predict._elnet <-
251306
function(object, new_data, type = NULL, penalty = NULL, ...) {
252307
if (any(names(enquos(...)) == "newdata"))
253308
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
254-
309+
255310
dots <- list(...)
256-
if (is.null(penalty))
257-
penalty <- object$fit$lambda
258-
dots$s <- penalty
259311

260312
object$spec <- eval_args(object$spec)
261-
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
313+
314+
if (is.null(penalty)) {
315+
penalty <- object$fit$lambda
316+
}
317+
318+
pred <- predict._elnet(object, new_data = new_data, type = "raw",
319+
opts = dots, penalty = penalty, multi = TRUE)
262320
param_key <- tibble(group = colnames(pred), penalty = penalty)
263321
pred <- as_tibble(pred)
264322
pred$.row <- 1:nrow(pred)

R/logistic_reg.R

Lines changed: 74 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@
6767
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")}
6868
#'
6969
#' When using `glmnet` models, there is the option to pass
70-
#' multiple values (or no values) to the `penalty` argument.
71-
#' This can have an effect on the model object results. When using
72-
#' the `predict()` method in these cases, the return object type
73-
#' depends on the value of `penalty`. If a single value is
74-
#' given, the results will be a simple numeric vector. When
75-
#' multiple values or no values for `penalty` are used in
76-
#' `logistic_reg()`, the `predict()` method will return a data frame with
77-
#' columns `values` and `lambda`.
70+
#' multiple values (or no values) to the `penalty` argument. This
71+
#' can have an effect on the model object results. When using the
72+
#' `predict()` method in these cases, the return value depends on
73+
#' the value of `penalty`. When using `predict()`, only a single
74+
#' value of the penalty can be used. When predicting on multiple
75+
#' penalties, the `multi_predict()` function can be used. It
76+
#' returns a tibble with a list column called `.pred` that contains
77+
#' a tibble with all of the penalty results.
7878
#'
7979
#' For prediction, the `stan` engine can compute posterior
8080
#' intervals analogous to confidence and prediction intervals. In
@@ -235,41 +235,41 @@ organize_glmnet_prob <- function(x, object) {
235235
}
236236

237237
# ------------------------------------------------------------------------------
238+
# glmnet call stack for logistic regression using `predict` when object has
239+
# classes "_lognet" and "model_fit" (for class predictions):
240+
#
241+
# predict()
242+
# predict._lognet(penalty = NULL) <-- checks and sets penalty
243+
# predict.model_fit() <-- checks for extra vars in ...
244+
# predict_class()
245+
# predict_class._lognet()
246+
# predict_class.model_fit()
247+
# predict.lognet()
248+
249+
250+
# glmnet call stack for logistic regression using `multi_predict` when object has
251+
# classes "_lognet" and "model_fit" (for class predictions):
252+
#
253+
# multi_predict()
254+
# multi_predict._lognet(penalty = NULL)
255+
# predict._lognet(multi = TRUE) <-- checks and sets penalty
256+
# predict.model_fit() <-- checks for extra vars in ...
257+
# predict_raw()
258+
# predict_raw._lognet()
259+
# predict_raw.model_fit(opts = list(s = penalty))
260+
# predict.lognet()
238261

239-
#' @export
240-
predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
241-
if (any(names(enquos(...)) == "newdata"))
242-
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
243-
244-
object$spec <- eval_args(object$spec)
245-
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
246-
}
247-
248-
#' @export
249-
predict_class._lognet <- function (object, new_data, ...) {
250-
if (any(names(enquos(...)) == "newdata"))
251-
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
252-
253-
object$spec <- eval_args(object$spec)
254-
predict_class.model_fit(object, new_data = new_data, ...)
255-
}
262+
# ------------------------------------------------------------------------------
256263

257264
#' @export
258-
predict_classprob._lognet <- function (object, new_data, ...) {
265+
predict._lognet <- function (object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
259266
if (any(names(enquos(...)) == "newdata"))
260267
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
261268

262-
object$spec <- eval_args(object$spec)
263-
predict_classprob.model_fit(object, new_data = new_data, ...)
264-
}
265-
266-
#' @export
267-
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
268-
if (any(names(enquos(...)) == "newdata"))
269-
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
269+
object$spec$args$penalty <- check_penalty(penalty, object, multi)
270270

271271
object$spec <- eval_args(object$spec)
272-
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
272+
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
273273
}
274274

275275

@@ -281,23 +281,26 @@ multi_predict._lognet <-
281281
if (any(names(enquos(...)) == "newdata"))
282282
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
283283

284+
if (is_quosure(penalty))
285+
penalty <- eval_tidy(penalty)
286+
284287
dots <- list(...)
285288
if (is.null(penalty))
286-
penalty <- object$fit$lambda
289+
penalty <- eval_tidy(object$fit$lambda)
287290
dots$s <- penalty
288291

289292
if (is.null(type))
290293
type <- "class"
291-
if (!(type %in% c("class", "prob", "link"))) {
292-
stop ("`type` should be either 'class', 'link', or 'prob'.", call. = FALSE)
294+
if (!(type %in% c("class", "prob", "link", "raw"))) {
295+
stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
293296
}
294297
if (type == "prob")
295298
dots$type <- "response"
296299
else
297300
dots$type <- type
298301

299302
object$spec <- eval_args(object$spec)
300-
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
303+
pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots)
301304
param_key <- tibble(group = colnames(pred), penalty = penalty)
302305
pred <- as_tibble(pred)
303306
pred$.row <- 1:nrow(pred)
@@ -321,6 +324,38 @@ multi_predict._lognet <-
321324
tibble(.pred = pred)
322325
}
323326

327+
328+
329+
330+
331+
#' @export
332+
predict_class._lognet <- function (object, new_data, ...) {
333+
if (any(names(enquos(...)) == "newdata"))
334+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
335+
336+
object$spec <- eval_args(object$spec)
337+
predict_class.model_fit(object, new_data = new_data, ...)
338+
}
339+
340+
#' @export
341+
predict_classprob._lognet <- function (object, new_data, ...) {
342+
if (any(names(enquos(...)) == "newdata"))
343+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
344+
345+
object$spec <- eval_args(object$spec)
346+
predict_classprob.model_fit(object, new_data = new_data, ...)
347+
}
348+
349+
#' @export
350+
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
351+
if (any(names(enquos(...)) == "newdata"))
352+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
353+
354+
object$spec <- eval_args(object$spec)
355+
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
356+
}
357+
358+
324359
# ------------------------------------------------------------------------------
325360

326361
#' @importFrom utils globalVariables

0 commit comments

Comments
 (0)