@@ -206,92 +206,14 @@ organize_glmnet_prob <- function(x, object) {
206206 res
207207}
208208
209- # ------------------------------------------------------------------------------
210- # glmnet call stack for logistic regression using `predict` when object has
211- # classes "_lognet" and "model_fit" (for class predictions):
212- #
213- # predict()
214- # predict._lognet(penalty = NULL) <-- checks and sets penalty
215- # predict.model_fit() <-- checks for extra vars in ...
216- # predict_class()
217- # predict_class._lognet()
218- # predict_class.model_fit()
219- # predict.lognet()
220-
221-
222- # glmnet call stack for logistic regression using `multi_predict` when object has
223- # classes "_lognet" and "model_fit" (for class predictions):
224- #
225- # multi_predict()
226- # multi_predict._lognet(penalty = NULL)
227- # predict._lognet(multi = TRUE) <-- checks and sets penalty
228- # predict.model_fit() <-- checks for extra vars in ...
229- # predict_raw()
230- # predict_raw._lognet()
231- # predict_raw.model_fit(opts = list(s = penalty))
232- # predict.lognet()
233-
234209# ------------------------------------------------------------------------------
235210
236211# ' @export
237- predict._lognet <- function (object , new_data , type = NULL , opts = list (), penalty = NULL , multi = FALSE , ... ) {
238- if (any(names(enquos(... )) == " newdata" ))
239- rlang :: abort(" Did you mean to use `new_data` instead of `newdata`?" )
240-
241- # See discussion in https://github.com/tidymodels/parsnip/issues/195
242- if (is.null(penalty ) & ! is.null(object $ spec $ args $ penalty )) {
243- penalty <- object $ spec $ args $ penalty
244- }
245-
246- object $ spec $ args $ penalty <- .check_glmnet_penalty_predict(penalty , object , multi )
247-
248- object $ spec <- eval_args(object $ spec )
249- predict.model_fit(object , new_data = new_data , type = type , opts = opts , ... )
250- }
251-
212+ predict._lognet <- predict_glmnet
252213
253214# ' @export
254215# ' @rdname multi_predict
255- multi_predict._lognet <-
256- function (object , new_data , type = NULL , penalty = NULL , ... ) {
257- if (any(names(enquos(... )) == " newdata" ))
258- rlang :: abort(" Did you mean to use `new_data` instead of `newdata`?" )
259-
260- if (is_quosure(penalty ))
261- penalty <- eval_tidy(penalty )
262-
263- dots <- list (... )
264-
265- if (is.null(penalty )) {
266- # See discussion in https://github.com/tidymodels/parsnip/issues/195
267- if (! is.null(object $ spec $ args $ penalty )) {
268- penalty <- object $ spec $ args $ penalty
269- } else {
270- penalty <- object $ fit $ lambda
271- }
272- }
273-
274- if (is.null(type ))
275- type <- " class"
276- if (! (type %in% c(" class" , " prob" , " link" , " raw" ))) {
277- rlang :: abort(" `type` should be either 'class', 'link', 'raw', or 'prob'." )
278- }
279- if (type == " prob" )
280- dots $ type <- " response"
281- else
282- dots $ type <- type
283-
284- object $ spec <- eval_args(object $ spec )
285- pred <- predict._lognet(object , new_data = new_data , type = " raw" ,
286- opts = dots , penalty = penalty , multi = TRUE )
287-
288- format_glmnet_multi_logistic_reg(
289- pred ,
290- penalty ,
291- type = dots $ type ,
292- lvl = object $ lvl
293- )
294- }
216+ multi_predict._lognet <- multi_predict_glmnet
295217
296218format_glmnet_multi_logistic_reg <- function (pred , penalty , type , lvl ) {
297219 param_key <- tibble(group = colnames(pred ), penalty = penalty )
@@ -324,32 +246,13 @@ format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
324246
325247
326248# ' @export
327- predict_class._lognet <- function (object , new_data , ... ) {
328- if (any(names(enquos(... )) == " newdata" ))
329- rlang :: abort(" Did you mean to use `new_data` instead of `newdata`?" )
330-
331- object $ spec <- eval_args(object $ spec )
332- predict_class.model_fit(object , new_data = new_data , ... )
333- }
249+ predict_class._lognet <- predict_class_glmnet
334250
335251# ' @export
336- predict_classprob._lognet <- function (object , new_data , ... ) {
337- if (any(names(enquos(... )) == " newdata" ))
338- rlang :: abort(" Did you mean to use `new_data` instead of `newdata`?" )
339-
340- object $ spec <- eval_args(object $ spec )
341- predict_classprob.model_fit(object , new_data = new_data , ... )
342- }
252+ predict_classprob._lognet <- predict_classprob_glmnet
343253
344254# ' @export
345- predict_raw._lognet <- function (object , new_data , opts = list (), ... ) {
346- if (any(names(enquos(... )) == " newdata" ))
347- rlang :: abort(" Did you mean to use `new_data` instead of `newdata`?" )
348-
349- object $ spec <- eval_args(object $ spec )
350- opts $ s <- object $ spec $ args $ penalty
351- predict_raw.model_fit(object , new_data = new_data , opts = opts , ... )
352- }
255+ predict_raw._lognet <- predict_raw_glmnet
353256
354257# ------------------------------------------------------------------------------
355258
0 commit comments