@@ -325,9 +325,41 @@ stan_conf_int <- function(object, newdata) {
325325
326326# ------------------------------------------------------------------------------
327327
328- # For `predict` methods that use `glmnet`, we have specific methods.
329- # Only one value of the penalty should be allowed when called by `predict()`:
330328
329+ # ' Helper functions for checking the penalty of glmnet models
330+ # '
331+ # ' @description
332+ # ' `check_glmnet_penalty()` checks that the model specification for fitting a
333+ # ' glmnet model contains a single value.
334+ # '
335+ # ' `check_penalty()` checks that the penalty value used for prediction is valid.
336+ # ' If called by `predict()`, it needs to be a single value. Multiple values are
337+ # ' allowed for `multi_predict()`.
338+ # '
339+ # ' @param x An object of class `model_spec`.
340+ # ' @rdname glmnet_helpers
341+ # ' @keywords internal
342+ # ' @export
343+ check_glmnet_penalty <- function (x ) {
344+ pen <- rlang :: eval_tidy(x $ args $ penalty )
345+
346+ if (length(pen ) != 1 ) {
347+ rlang :: abort(c(
348+ " For the glmnet engine, `penalty` must be a single number (or a value of `tune()`)." ,
349+ glue :: glue(" There are {length(pen)} values for `penalty`." ),
350+ " To try multiple values for total regularization, use the tune package." ,
351+ " To predict multiple penalties, use `multi_predict()`"
352+ ))
353+ }
354+ }
355+
356+ # ' @param penalty A penalty value to check.
357+ # ' @param object An object of class `model_fit`.
358+ # ' @param multi A logical indicating if multiple values are allowed.
359+ # '
360+ # ' @rdname glmnet_helpers
361+ # ' @keywords internal
362+ # ' @export
331363check_penalty <- function (penalty = NULL , object , multi = FALSE ) {
332364
333365 if (is.null(penalty )) {
@@ -356,16 +388,3 @@ check_penalty <- function(penalty = NULL, object, multi = FALSE) {
356388
357389 penalty
358390}
359-
360- check_glmnet_penalty <- function (x ) {
361- pen <- rlang :: eval_tidy(x $ args $ penalty )
362-
363- if (length(pen ) != 1 ) {
364- rlang :: abort(c(
365- " For the glmnet engine, `penalty` must be a single number (or a value of `tune()`)." ,
366- glue :: glue(" There are {length(pen)} values for `penalty`." ),
367- " To try multiple values for total regularization, use the tune package." ,
368- " To predict multiple penalties, use `multi_predict()`"
369- ))
370- }
371- }
0 commit comments