Skip to content

Commit cf8f3c7

Browse files
committed
move check_penalty()
1 parent 2a8da60 commit cf8f3c7

File tree

2 files changed

+34
-34
lines changed

2 files changed

+34
-34
lines changed

R/linear_reg.R

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -208,40 +208,6 @@ organize_glmnet_pred <- function(x, object) {
208208
}
209209

210210

211-
# ------------------------------------------------------------------------------
212-
213-
# For `predict` methods that use `glmnet`, we have specific methods.
214-
# Only one value of the penalty should be allowed when called by `predict()`:
215-
216-
check_penalty <- function(penalty = NULL, object, multi = FALSE) {
217-
218-
if (is.null(penalty)) {
219-
penalty <- object$fit$lambda
220-
}
221-
222-
# when using `predict()`, allow for a single lambda
223-
if (!multi) {
224-
if (length(penalty) != 1)
225-
rlang::abort(
226-
glue::glue(
227-
"`penalty` should be a single numeric value. `multi_predict()` ",
228-
"can be used to get multiple predictions per row of data.",
229-
)
230-
)
231-
}
232-
233-
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
234-
rlang::abort(
235-
glue::glue(
236-
"The glmnet model was fit with a single penalty value of ",
237-
"{object$fit$lambda}. Predicting with a value of {penalty} ",
238-
"will give incorrect results from `glmnet()`."
239-
)
240-
)
241-
242-
penalty
243-
}
244-
245211
# ------------------------------------------------------------------------------
246212
# glmnet call stack for linear regression using `predict` when object has
247213
# classes "_elnet" and "model_fit":

R/misc.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,40 @@ stan_conf_int <- function(object, newdata) {
323323
rlang::eval_tidy(fn)
324324
}
325325

326+
# ------------------------------------------------------------------------------
327+
328+
# For `predict` methods that use `glmnet`, we have specific methods.
329+
# Only one value of the penalty should be allowed when called by `predict()`:
330+
331+
check_penalty <- function(penalty = NULL, object, multi = FALSE) {
332+
333+
if (is.null(penalty)) {
334+
penalty <- object$fit$lambda
335+
}
336+
337+
# when using `predict()`, allow for a single lambda
338+
if (!multi) {
339+
if (length(penalty) != 1)
340+
rlang::abort(
341+
glue::glue(
342+
"`penalty` should be a single numeric value. `multi_predict()` ",
343+
"can be used to get multiple predictions per row of data.",
344+
)
345+
)
346+
}
347+
348+
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
349+
rlang::abort(
350+
glue::glue(
351+
"The glmnet model was fit with a single penalty value of ",
352+
"{object$fit$lambda}. Predicting with a value of {penalty} ",
353+
"will give incorrect results from `glmnet()`."
354+
)
355+
)
356+
357+
penalty
358+
}
359+
326360
check_glmnet_penalty <- function(x) {
327361
pen <- rlang::eval_tidy(x$args$penalty)
328362

0 commit comments

Comments
 (0)