6767# ' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")}
6868# '
6969# ' When using `glmnet` models, there is the option to pass
70- # ' multiple values (or no values) to the `penalty` argument.
71- # ' This can have an effect on the model object results. When using
72- # ' the `predict()` method in these cases, the return object type
73- # ' depends on the value of `penalty`. If a single value is
74- # ' given, the results will be a simple numeric vector. When
75- # ' multiple values or no values for `penalty` are used in
76- # ' `logistic_reg()`, the `predict()` method will return a data frame with
77- # ' columns `values` and `lambda` .
70+ # ' multiple values (or no values) to the `penalty` argument. This
71+ # ' can have an effect on the model object results. When using the
72+ # ' `predict()` method in these cases, the return value depends on
73+ # ' the value of `penalty`. When using `predict()`, only a single
74+ # ' value of the penalty can be used. When predicting on multiple
75+ # ' penalties, the `multi_predict()` function can be used. It
76+ # ' returns a tibble with a list column called `.pred` that contains
77+ # ' a tibble with all of the penalty results .
7878# '
7979# ' For prediction, the `stan` engine can compute posterior
8080# ' intervals analogous to confidence and prediction intervals. In
@@ -235,41 +235,41 @@ organize_glmnet_prob <- function(x, object) {
235235}
236236
237237# ------------------------------------------------------------------------------
238+ # glmnet call stack for linear regression using `predict` when object has
239+ # classes "_lognet" and "model_fit" (for class predictions):
240+ #
241+ # predict()
242+ # predict._lognet(penalty = NULL) <-- checks and sets penalty
243+ # predict.model_fit() <-- checks for extra vars in ...
244+ # predict_class()
245+ # predict_class._lognet()
246+ # predict_class.model_fit()
247+ # predict.lognet()
248+
249+
250+ # glmnet call stack for linear regression using `multi_predict` when object has
251+ # classes "_lognet" and "model_fit" (for class predictions):
252+ #
253+ # multi_predict()
254+ # multi_predict._lognet(penalty = NULL)
255+ # predict._lognet(multi = TRUE) <-- checks and sets penalty
256+ # predict.model_fit() <-- checks for extra vars in ...
257+ # predict_raw()
258+ # predict_raw._lognet()
259+ # predict_raw.model_fit(opts = list(s = penalty))
260+ # predict.lognet()
238261
239- # ' @export
240- predict._lognet <- function (object , new_data , type = NULL , opts = list (), ... ) {
241- if (any(names(enquos(... )) == " newdata" ))
242- stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
243-
244- object $ spec <- eval_args(object $ spec )
245- predict.model_fit(object , new_data = new_data , type = type , opts = opts , ... )
246- }
247-
248- # ' @export
249- predict_class._lognet <- function (object , new_data , ... ) {
250- if (any(names(enquos(... )) == " newdata" ))
251- stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
252-
253- object $ spec <- eval_args(object $ spec )
254- predict_class.model_fit(object , new_data = new_data , ... )
255- }
262+ # ------------------------------------------------------------------------------
256263
257264# ' @export
258- predict_classprob ._lognet <- function (object , new_data , ... ) {
265+ predict ._lognet <- function (object , new_data , type = NULL , opts = list (), penalty = NULL , multi = FALSE , ... ) {
259266 if (any(names(enquos(... )) == " newdata" ))
260267 stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
261268
262- object $ spec <- eval_args(object $ spec )
263- predict_classprob.model_fit(object , new_data = new_data , ... )
264- }
265-
266- # ' @export
267- predict_raw._lognet <- function (object , new_data , opts = list (), ... ) {
268- if (any(names(enquos(... )) == " newdata" ))
269- stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
269+ object $ spec $ args $ penalty <- check_penalty(penalty , object , multi )
270270
271271 object $ spec <- eval_args(object $ spec )
272- predict_raw .model_fit(object , new_data = new_data , opts = opts , ... )
272+ predict .model_fit(object , new_data = new_data , type = type , opts = opts , ... )
273273}
274274
275275
@@ -281,23 +281,26 @@ multi_predict._lognet <-
281281 if (any(names(enquos(... )) == " newdata" ))
282282 stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
283283
284+ if (is_quosure(penalty ))
285+ penalty <- eval_tidy(penalty )
286+
284287 dots <- list (... )
285288 if (is.null(penalty ))
286- penalty <- object $ fit $ lambda
289+ penalty <- eval_tidy( object $ fit $ lambda )
287290 dots $ s <- penalty
288291
289292 if (is.null(type ))
290293 type <- " class"
291- if (! (type %in% c(" class" , " prob" , " link" ))) {
292- stop (" `type` should be either 'class', 'link', or 'prob'." , call. = FALSE )
294+ if (! (type %in% c(" class" , " prob" , " link" , " raw " ))) {
295+ stop (" `type` should be either 'class', 'link', 'raw', or 'prob'." , call. = FALSE )
293296 }
294297 if (type == " prob" )
295298 dots $ type <- " response"
296299 else
297300 dots $ type <- type
298301
299302 object $ spec <- eval_args(object $ spec )
300- pred <- predict(object , new_data = new_data , type = " raw" , opts = dots )
303+ pred <- predict.model_fit (object , new_data = new_data , type = " raw" , opts = dots )
301304 param_key <- tibble(group = colnames(pred ), penalty = penalty )
302305 pred <- as_tibble(pred )
303306 pred $ .row <- 1 : nrow(pred )
@@ -321,6 +324,38 @@ multi_predict._lognet <-
321324 tibble(.pred = pred )
322325 }
323326
327+
328+
329+
330+
331+ # ' @export
332+ predict_class._lognet <- function (object , new_data , ... ) {
333+ if (any(names(enquos(... )) == " newdata" ))
334+ stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
335+
336+ object $ spec <- eval_args(object $ spec )
337+ predict_class.model_fit(object , new_data = new_data , ... )
338+ }
339+
340+ # ' @export
341+ predict_classprob._lognet <- function (object , new_data , ... ) {
342+ if (any(names(enquos(... )) == " newdata" ))
343+ stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
344+
345+ object $ spec <- eval_args(object $ spec )
346+ predict_classprob.model_fit(object , new_data = new_data , ... )
347+ }
348+
349+ # ' @export
350+ predict_raw._lognet <- function (object , new_data , opts = list (), ... ) {
351+ if (any(names(enquos(... )) == " newdata" ))
352+ stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
353+
354+ object $ spec <- eval_args(object $ spec )
355+ predict_raw.model_fit(object , new_data = new_data , opts = opts , ... )
356+ }
357+
358+
324359# ------------------------------------------------------------------------------
325360
326361# ' @importFrom utils globalVariables
0 commit comments