@@ -285,34 +285,41 @@ multi_predict._lognet <-
285285 pred <- predict._lognet(object , new_data = new_data , type = " raw" ,
286286 opts = dots , penalty = penalty , multi = TRUE )
287287
288- param_key <- tibble(group = colnames(pred ), penalty = penalty )
289- pred <- as_tibble(pred )
290- pred $ .row <- 1 : nrow(pred )
291- pred <- gather(pred , group , .pred_class , - .row )
292- if (dots $ type == " class" ) {
293- pred [[" .pred_class" ]] <- factor (pred [[" .pred_class" ]], levels = object $ lvl )
294- } else {
295- if (dots $ type == " response" ) {
296- pred [[" .pred2" ]] <- 1 - pred [[" .pred_class" ]]
297- names(pred ) <- c(" .row" , " group" , paste0(" .pred_" , rev(object $ lvl )))
298- pred <- pred [, c(" .row" , " group" , paste0(" .pred_" , object $ lvl ))]
299- }
300- }
301- if (utils :: packageVersion(" dplyr" ) > = " 1.0.99.9000" ) {
302- pred <- full_join(param_key , pred , by = " group" , multiple = " all" )
303- } else {
304- pred <- full_join(param_key , pred , by = " group" )
305- }
306- pred $ group <- NULL
307- pred <- arrange(pred , .row , penalty )
308- .row <- pred $ .row
309- pred $ .row <- NULL
310- pred <- split(pred , .row )
311- names(pred ) <- NULL
312- tibble(.pred = pred )
288+ format_glmnet_multi_logistic_reg(
289+ pred ,
290+ penalty ,
291+ type = dots $ type ,
292+ lvl = object $ lvl
293+ )
313294 }
314295
315-
296+ format_glmnet_multi_logistic_reg <- function (pred , penalty , type , lvl ) {
297+ param_key <- tibble(group = colnames(pred ), penalty = penalty )
298+ pred <- as_tibble(pred )
299+ pred $ .row <- 1 : nrow(pred )
300+ pred <- gather(pred , group , .pred_class , - .row )
301+ if (type == " class" ) {
302+ pred [[" .pred_class" ]] <- factor (pred [[" .pred_class" ]], levels = lvl )
303+ } else {
304+ if (type == " response" ) {
305+ pred [[" .pred2" ]] <- 1 - pred [[" .pred_class" ]]
306+ names(pred ) <- c(" .row" , " group" , paste0(" .pred_" , rev(lvl )))
307+ pred <- pred [, c(" .row" , " group" , paste0(" .pred_" , lvl ))]
308+ }
309+ }
310+ if (utils :: packageVersion(" dplyr" ) > = " 1.0.99.9000" ) {
311+ pred <- full_join(param_key , pred , by = " group" , multiple = " all" )
312+ } else {
313+ pred <- full_join(param_key , pred , by = " group" )
314+ }
315+ pred $ group <- NULL
316+ pred <- arrange(pred , .row , penalty )
317+ .row <- pred $ .row
318+ pred $ .row <- NULL
319+ pred <- split(pred , .row )
320+ names(pred ) <- NULL
321+ tibble(.pred = pred )
322+ }
316323
317324
318325
0 commit comments