Skip to content

Commit 3aa1e6c

Browse files
hfricksimonpcouch
andauthored
encapsulate glmnet formatting code (#867)
* encapsulate formatting code * Apply suggestions from code review Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com> --------- Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com>
1 parent 3352b39 commit 3aa1e6c

File tree

3 files changed

+93
-71
lines changed

3 files changed

+93
-71
lines changed

R/linear_reg.R

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -241,20 +241,25 @@ multi_predict._elnet <-
241241

242242
pred <- predict._elnet(object, new_data = new_data, type = "raw",
243243
opts = dots, penalty = penalty, multi = TRUE)
244-
param_key <- tibble(group = colnames(pred), penalty = penalty)
245-
pred <- as_tibble(pred)
246-
pred$.row <- 1:nrow(pred)
247-
pred <- gather(pred, group, .pred, -.row)
248-
if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
249-
pred <- full_join(param_key, pred, by = "group", multiple = "all")
250-
} else {
251-
pred <- full_join(param_key, pred, by = "group")
252-
}
253-
pred$group <- NULL
254-
pred <- arrange(pred, .row, penalty)
255-
.row <- pred$.row
256-
pred$.row <- NULL
257-
pred <- split(pred, .row)
258-
names(pred) <- NULL
259-
tibble(.pred = pred)
244+
245+
format_glmnet_multi_linear_reg(pred, penalty = penalty)
260246
}
247+
248+
format_glmnet_multi_linear_reg <- function(pred, penalty) {
249+
param_key <- tibble(group = colnames(pred), penalty = penalty)
250+
pred <- as_tibble(pred)
251+
pred$.row <- 1:nrow(pred)
252+
pred <- gather(pred, group, .pred, -.row)
253+
if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
254+
pred <- full_join(param_key, pred, by = "group", multiple = "all")
255+
} else {
256+
pred <- full_join(param_key, pred, by = "group")
257+
}
258+
pred$group <- NULL
259+
pred <- arrange(pred, .row, penalty)
260+
.row <- pred$.row
261+
pred$.row <- NULL
262+
pred <- split(pred, .row)
263+
names(pred) <- NULL
264+
tibble(.pred = pred)
265+
}

R/logistic_reg.R

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -285,34 +285,41 @@ multi_predict._lognet <-
285285
pred <- predict._lognet(object, new_data = new_data, type = "raw",
286286
opts = dots, penalty = penalty, multi = TRUE)
287287

288-
param_key <- tibble(group = colnames(pred), penalty = penalty)
289-
pred <- as_tibble(pred)
290-
pred$.row <- 1:nrow(pred)
291-
pred <- gather(pred, group, .pred_class, -.row)
292-
if (dots$type == "class") {
293-
pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = object$lvl)
294-
} else {
295-
if (dots$type == "response") {
296-
pred[[".pred2"]] <- 1 - pred[[".pred_class"]]
297-
names(pred) <- c(".row", "group", paste0(".pred_", rev(object$lvl)))
298-
pred <- pred[, c(".row", "group", paste0(".pred_", object$lvl))]
299-
}
300-
}
301-
if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
302-
pred <- full_join(param_key, pred, by = "group", multiple = "all")
303-
} else {
304-
pred <- full_join(param_key, pred, by = "group")
305-
}
306-
pred$group <- NULL
307-
pred <- arrange(pred, .row, penalty)
308-
.row <- pred$.row
309-
pred$.row <- NULL
310-
pred <- split(pred, .row)
311-
names(pred) <- NULL
312-
tibble(.pred = pred)
288+
format_glmnet_multi_logistic_reg(
289+
pred,
290+
penalty,
291+
type = dots$type,
292+
lvl = object$lvl
293+
)
313294
}
314295

315-
296+
format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
297+
param_key <- tibble(group = colnames(pred), penalty = penalty)
298+
pred <- as_tibble(pred)
299+
pred$.row <- 1:nrow(pred)
300+
pred <- gather(pred, group, .pred_class, -.row)
301+
if (type == "class") {
302+
pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = lvl)
303+
} else {
304+
if (type == "response") {
305+
pred[[".pred2"]] <- 1 - pred[[".pred_class"]]
306+
names(pred) <- c(".row", "group", paste0(".pred_", rev(lvl)))
307+
pred <- pred[, c(".row", "group", paste0(".pred_", lvl))]
308+
}
309+
}
310+
if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
311+
pred <- full_join(param_key, pred, by = "group", multiple = "all")
312+
} else {
313+
pred <- full_join(param_key, pred, by = "group")
314+
}
315+
pred$group <- NULL
316+
pred <- arrange(pred, .row, penalty)
317+
.row <- pred$.row
318+
pred$.row <- NULL
319+
pred <- split(pred, .row)
320+
names(pred) <- NULL
321+
tibble(.pred = pred)
322+
}
316323

317324

318325

R/multinom_reg.R

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -224,35 +224,13 @@ multi_predict._multnet <-
224224
pred <- predict._multnet(object, new_data = new_data, type = "raw",
225225
opts = dots, penalty = penalty, multi = TRUE)
226226

227-
format_probs <- function(x) {
228-
x <- as_tibble(x)
229-
names(x) <- paste0(".pred_", names(x))
230-
nms <- names(x)
231-
x$.row <- 1:nrow(x)
232-
x[, c(".row", nms)]
233-
}
234-
235-
if (type == "prob") {
236-
pred <- apply(pred, 3, format_probs)
237-
names(pred) <- NULL
238-
pred <- map_dfr(pred, function(x) x)
239-
pred$penalty <- rep(penalty, each = nrow(new_data))
240-
pred <- dplyr::relocate(pred, penalty)
241-
} else {
242-
pred <-
243-
tibble(
244-
.row = rep(1:nrow(new_data), length(penalty)),
245-
penalty = rep(penalty, each = nrow(new_data)),
246-
.pred_class = factor(as.vector(pred), levels = object$lvl)
247-
)
248-
}
249-
250-
pred <- arrange(pred, .row, penalty)
251-
.row <- pred$.row
252-
pred$.row <- NULL
253-
pred <- split(pred, .row)
254-
names(pred) <- NULL
255-
tibble(.pred = pred)
227+
format_glmnet_multi_multinom_reg(
228+
pred,
229+
penalty = penalty,
230+
type = type,
231+
n_rows = nrow(new_data),
232+
lvl = object$lvl
233+
)
256234
}
257235

258236
#' @export
@@ -273,3 +251,35 @@ predict_raw._multnet <- function(object, new_data, opts = list(), ...) {
273251
opts$s <- object$spec$args$penalty
274252
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
275253
}
254+
255+
format_glmnet_multi_multinom_reg <- function(pred, penalty, type, n_rows, lvl) {
256+
format_probs <- function(x) {
257+
x <- as_tibble(x)
258+
names(x) <- paste0(".pred_", names(x))
259+
nms <- names(x)
260+
x$.row <- 1:nrow(x)
261+
x[, c(".row", nms)]
262+
}
263+
264+
if (type == "prob") {
265+
pred <- apply(pred, 3, format_probs)
266+
names(pred) <- NULL
267+
pred <- map_dfr(pred, function(x) x)
268+
pred$penalty <- rep(penalty, each = n_rows)
269+
pred <- dplyr::relocate(pred, penalty)
270+
} else {
271+
pred <-
272+
tibble(
273+
.row = rep(1:n_rows, length(penalty)),
274+
penalty = rep(penalty, each = n_rows),
275+
.pred_class = factor(as.vector(pred), levels = lvl)
276+
)
277+
}
278+
279+
pred <- arrange(pred, .row, penalty)
280+
.row <- pred$.row
281+
pred$.row <- NULL
282+
pred <- split(pred, .row)
283+
names(pred) <- NULL
284+
tibble(.pred = pred)
285+
}

0 commit comments

Comments
 (0)