Skip to content

Commit db16e3e

Browse files
committed
document and export glmnet helpers
1 parent cf8f3c7 commit db16e3e

File tree

4 files changed

+65
-16
lines changed

4 files changed

+65
-16
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ export(augment)
122122
export(boost_tree)
123123
export(check_empty_ellipse)
124124
export(check_final_param)
125+
export(check_glmnet_penalty)
126+
export(check_penalty)
125127
export(contr_one_hot)
126128
export(control_parsnip)
127129
export(convert_stan_interval)

R/linear_reg.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ organize_glmnet_pred <- function(x, object) {
207207
res
208208
}
209209

210-
211210
# ------------------------------------------------------------------------------
212211
# glmnet call stack for linear regression using `predict` when object has
213212
# classes "_elnet" and "model_fit":

R/misc.R

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,41 @@ stan_conf_int <- function(object, newdata) {
325325

326326
# ------------------------------------------------------------------------------
327327

328-
# For `predict` methods that use `glmnet`, we have specific methods.
329-
# Only one value of the penalty should be allowed when called by `predict()`:
330328

329+
#' Helper functions for checking the penalty of glmnet models
330+
#'
331+
#' @description
332+
#' `check_glmnet_penalty()` checks that the model specification for fitting a
333+
#' glmnet model contains a single value.
334+
#'
335+
#' `check_penalty()` checks that the penalty value used for prediction is valid.
336+
#' If called by `predict()`, it needs to be a single value. Multiple values are
337+
#' allowed for `multi_predict()`.
338+
#'
339+
#' @param x An object of class `model_spec`.
340+
#' @rdname glmnet_helpers
341+
#' @keywords internal
342+
#' @export
343+
check_glmnet_penalty <- function(x) {
344+
pen <- rlang::eval_tidy(x$args$penalty)
345+
346+
if (length(pen) != 1) {
347+
rlang::abort(c(
348+
"For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
349+
glue::glue("There are {length(pen)} values for `penalty`."),
350+
"To try multiple values for total regularization, use the tune package.",
351+
"To predict multiple penalties, use `multi_predict()`"
352+
))
353+
}
354+
}
355+
356+
#' @param penalty A penalty value to check.
357+
#' @param object An object of class `model_fit`.
358+
#' @param multi A logical indicating if multiple values are allowed.
359+
#'
360+
#' @rdname glmnet_helpers
361+
#' @keywords internal
362+
#' @export
331363
check_penalty <- function(penalty = NULL, object, multi = FALSE) {
332364

333365
if (is.null(penalty)) {
@@ -356,16 +388,3 @@ check_penalty <- function(penalty = NULL, object, multi = FALSE) {
356388

357389
penalty
358390
}
359-
360-
check_glmnet_penalty <- function(x) {
361-
pen <- rlang::eval_tidy(x$args$penalty)
362-
363-
if (length(pen) != 1) {
364-
rlang::abort(c(
365-
"For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
366-
glue::glue("There are {length(pen)} values for `penalty`."),
367-
"To try multiple values for total regularization, use the tune package.",
368-
"To predict multiple penalties, use `multi_predict()`"
369-
))
370-
}
371-
}

man/glmnet_helpers.Rd

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)