Skip to content

Commit f7ba069

Browse files
authored
Merge pull request #509 from tidymodels/export-glmnet-helpers
Export glmnet helpers
2 parents 7d78009 + 522b8ad commit f7ba069

File tree

8 files changed

+126
-60
lines changed

8 files changed

+126
-60
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ S3method(varying_args,model_spec)
108108
S3method(varying_args,recipe)
109109
S3method(varying_args,step)
110110
export("%>%")
111+
export(.check_glmnet_penalty_fit)
112+
export(.check_glmnet_penalty_predict)
111113
export(.cols)
112114
export(.convert_form_to_xy_fit)
113115
export(.convert_form_to_xy_new)
@@ -117,6 +119,7 @@ export(.dat)
117119
export(.facts)
118120
export(.lvls)
119121
export(.obs)
122+
export(.organize_glmnet_pred)
120123
export(.preds)
121124
export(.x)
122125
export(.y)

R/linear_reg.R

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
107107
x <- translate.default(x, engine, ...)
108108

109109
if (engine == "glmnet") {
110-
check_glmnet_penalty(x)
110+
.check_glmnet_penalty_fit(x)
111111
if (any(names(x$eng_args) == "path_values")) {
112112
# Since we decouple the parsnip `penalty` argument from being the same
113113
# as the glmnet `lambda` value, `path_values` allows users to set the
@@ -192,7 +192,18 @@ check_args.linear_reg <- function(object) {
192192

193193
# ------------------------------------------------------------------------------
194194

195-
organize_glmnet_pred <- function(x, object) {
195+
#' Organize glmnet predictions
196+
#'
197+
#' This function is for developer use and organizes predictions from glmnet
198+
#' models.
199+
#'
200+
#' @param x Predictions as returned by the `predict()` method for glmnet models.
201+
#' @param object An object of class `model_fit`.
202+
#'
203+
#' @rdname glmnet_helpers_prediction
204+
#' @keywords internal
205+
#' @export
206+
.organize_glmnet_pred <- function(x, object) {
196207
if (ncol(x) == 1) {
197208
res <- x[, 1]
198209
res <- unname(res)
@@ -207,41 +218,6 @@ organize_glmnet_pred <- function(x, object) {
207218
res
208219
}
209220

210-
211-
# ------------------------------------------------------------------------------
212-
213-
# For `predict` methods that use `glmnet`, we have specific methods.
214-
# Only one value of the penalty should be allowed when called by `predict()`:
215-
216-
check_penalty <- function(penalty = NULL, object, multi = FALSE) {
217-
218-
if (is.null(penalty)) {
219-
penalty <- object$fit$lambda
220-
}
221-
222-
# when using `predict()`, allow for a single lambda
223-
if (!multi) {
224-
if (length(penalty) != 1)
225-
rlang::abort(
226-
glue::glue(
227-
"`penalty` should be a single numeric value. `multi_predict()` ",
228-
"can be used to get multiple predictions per row of data.",
229-
)
230-
)
231-
}
232-
233-
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
234-
rlang::abort(
235-
glue::glue(
236-
"The glmnet model was fit with a single penalty value of ",
237-
"{object$fit$lambda}. Predicting with a value of {penalty} ",
238-
"will give incorrect results from `glmnet()`."
239-
)
240-
)
241-
242-
penalty
243-
}
244-
245221
# ------------------------------------------------------------------------------
246222
# glmnet call stack for linear regression using `predict` when object has
247223
# classes "_elnet" and "model_fit":
@@ -279,7 +255,7 @@ predict._elnet <-
279255
penalty <- object$spec$args$penalty
280256
}
281257

282-
object$spec$args$penalty <- check_penalty(penalty, object, multi)
258+
object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi)
283259

284260
object$spec <- eval_args(object$spec)
285261
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)

R/linear_reg_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ set_pred(
163163
type = "numeric",
164164
value = list(
165165
pre = NULL,
166-
post = organize_glmnet_pred,
166+
post = .organize_glmnet_pred,
167167
func = c(fun = "predict"),
168168
args =
169169
list(

R/logistic_reg.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) {
109109
arg_names <- names(arg_vals)
110110

111111
if (engine == "glmnet") {
112-
check_glmnet_penalty(x)
112+
.check_glmnet_penalty_fit(x)
113113
if (any(names(x$eng_args) == "path_values")) {
114114
# Since we decouple the parsnip `penalty` argument from being the same
115115
# as the glmnet `lambda` value, `path_values` allows users to set the
@@ -296,7 +296,7 @@ predict._lognet <- function(object, new_data, type = NULL, opts = list(), penalt
296296
penalty <- object$spec$args$penalty
297297
}
298298

299-
object$spec$args$penalty <- check_penalty(penalty, object, multi)
299+
object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi)
300300

301301
object$spec <- eval_args(object$spec)
302302
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)

R/misc.R

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,26 @@ stan_conf_int <- function(object, newdata) {
323323
rlang::eval_tidy(fn)
324324
}
325325

326-
check_glmnet_penalty <- function(x) {
326+
# ------------------------------------------------------------------------------
327+
328+
329+
#' Helper functions for checking the penalty of glmnet models
330+
#'
331+
#' @description
332+
#' These functions are for developer use.
333+
#'
334+
#' `.check_glmnet_penalty_fit()` checks that the model specification for fitting a
335+
#' glmnet model contains a single value.
336+
#'
337+
#' `.check_glmnet_penalty_predict()` checks that the penalty value used for prediction is valid.
338+
#' If called by `predict()`, it needs to be a single value. Multiple values are
339+
#' allowed for `multi_predict()`.
340+
#'
341+
#' @param x An object of class `model_spec`.
342+
#' @rdname glmnet_helpers
343+
#' @keywords internal
344+
#' @export
345+
.check_glmnet_penalty_fit <- function(x) {
327346
pen <- rlang::eval_tidy(x$args$penalty)
328347

329348
if (length(pen) != 1) {
@@ -335,3 +354,39 @@ check_glmnet_penalty <- function(x) {
335354
))
336355
}
337356
}
357+
358+
#' @param penalty A penalty value to check.
359+
#' @param object An object of class `model_fit`.
360+
#' @param multi A logical indicating if multiple values are allowed.
361+
#'
362+
#' @rdname glmnet_helpers
363+
#' @keywords internal
364+
#' @export
365+
.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) {
366+
367+
if (is.null(penalty)) {
368+
penalty <- object$fit$lambda
369+
}
370+
371+
# when using `predict()`, allow for a single lambda
372+
if (!multi) {
373+
if (length(penalty) != 1)
374+
rlang::abort(
375+
glue::glue(
376+
"`penalty` should be a single numeric value. `multi_predict()` ",
377+
"can be used to get multiple predictions per row of data.",
378+
)
379+
)
380+
}
381+
382+
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
383+
rlang::abort(
384+
glue::glue(
385+
"The glmnet model was fit with a single penalty value of ",
386+
"{object$fit$lambda}. Predicting with a value of {penalty} ",
387+
"will give incorrect results from `glmnet()`."
388+
)
389+
)
390+
391+
penalty
392+
}

R/multinom_reg.R

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ predict._multnet <-
222222
penalty <- object$spec$args$penalty
223223
}
224224

225-
object$spec$args$penalty <- check_penalty(penalty, object, multi)
225+
object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi)
226226

227227
object$spec <- eval_args(object$spec)
228228
res <- predict.model_fit(
@@ -317,20 +317,3 @@ predict_raw._multnet <- function(object, new_data, opts = list(), ...) {
317317
object$spec <- eval_args(object$spec)
318318
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
319319
}
320-
321-
322-
323-
# ------------------------------------------------------------------------------
324-
325-
# This checks as a pre-processor in the model data object
326-
check_glmnet_lambda <- function(dat, object) {
327-
if (length(object$fit$lambda) > 1)
328-
rlang::abort(
329-
glue::glue(
330-
"`predict()` doesn't work with multiple penalties (i.e. lambdas). ",
331-
"Please specify a single value using `penalty = some_value` or use ",
332-
"`multi_predict()` to get multiple predictions per row of data."
333-
)
334-
)
335-
dat
336-
}

man/glmnet_helpers.Rd

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/glmnet_helpers_prediction.Rd

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)