|
| 1 | +#' Create a ggplot for a model object |
| 2 | +#' |
| 3 | +#' This method provides a good visualization method for model results. |
| 4 | +#' Currently, only methods for glmnet models are implemented. |
| 5 | +#' |
| 6 | +#' @param object A model fit object. |
| 7 | +#' @param min_penalty A single, non-negative number for the smallest penalty |
| 8 | +#' value that should be shown in the plot. The default of zero shows the |
| 9 | +#' whole data range. |
| 10 | +#' @param best_penalty A single, non-negative number that will show a vertical |
| 11 | +#' line marker. If left `NULL`, no line is shown. When this argument is used, |
| 12 | +#' the \pkg{ggrepel} package is required. |
| 13 | +#' @param top_n A non-negative integer for how many model predictors to label. |
| 14 | +#' The top predictors are ranked by their absolute coefficient value. For |
| 15 | +#' multinomial or multivariate models, the `top_n` terms are selected within |
| 16 | +#' class or response, respectively. |
| 17 | +#' @param ... For [autoplot.glmnet()], options to pass to |
| 18 | +#' [ggrepel::geom_label_repel()]. Otherwise, this argument is ignored. |
| 19 | +#' @return A ggplot object with penalty on the x-axis and coefficients on the |
| 20 | +#' y-axis. For multinomial or multivariate models, the plot is faceted. |
| 21 | +#' @details The \pkg{glmnet} package will need to be attached or loaded for |
| 22 | +#' its `autoplot()` method to work correctly. |
| 23 | +#' |
| 24 | +# registered in zzz.R |
autoplot.model_fit <- function(object, ...) {
  # Delegate to the `autoplot()` method for the object stored in `fit`,
  # forwarding any extra arguments unchanged.
  engine_fit <- object$fit
  autoplot(engine_fit, ...)
}
| 28 | + |
| 29 | +# glmnet is not a formal dependency here. |
| 30 | +# unit tests are located at https://github.com/tidymodels/extratests |
| 31 | +# nocov start |
| 32 | + |
| 33 | +# registered in zzz.R |
| 34 | +#' @rdname autoplot.model_fit |
autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
                            top_n = 3L) {
  # Thin S3 wrapper: the plotting work is done by `autoplot_glmnet()`. The
  # dots are forwarded last.
  autoplot_glmnet(
    object,
    min_penalty,
    best_penalty,
    top_n,
    ...
  )
}
| 39 | + |
| 40 | + |
# Extract the coefficients of a glmnet fit in long ("tidy") form.
#
# When `coef()` returns a list, each element is reformatted separately and the
# results are row-bound with the element name stored in a `class` column.
# Otherwise the single coefficient object is reformatted directly.
map_glmnet_coefs <- function(x) {
  raw_coefs <- coef(x)
  # If parsnip is used to fit the model, glmnet should be attached and
  # `coef()` will work. If the object was loaded into a new session, the user
  # needs to load the package first.
  if (is.null(raw_coefs)) {
    rlang::abort("Please load the glmnet package before running `autoplot()`.")
  }

  num_predictors <- x$dim[1]
  penalties <- x$lambda

  if (!is.list(raw_coefs)) {
    return(reformat_coefs(raw_coefs, p = num_predictors, penalty = penalties))
  }

  group_names <- names(raw_coefs)
  per_group <- purrr::map(
    raw_coefs,
    reformat_coefs,
    p = num_predictors,
    penalty = penalties
  )
  purrr::map2_dfr(per_group, group_names, ~ dplyr::mutate(.x, class = .y))
}
| 59 | + |
# Convert one coefficient matrix (terms in rows, one column per penalty value)
# into a long tibble with columns `term`, `estimate` and `penalty`.
#
# `p` is the number of predictors and `penalty` is the vector of penalty
# values, one per column of `x`.
reformat_coefs <- function(x, p, penalty) {
  coef_mat <- as.matrix(x)
  n_rows <- nrow(coef_mat)
  if (n_rows > p) {
    # More rows than predictors: drop the intercept, which is first
    coef_mat <- coef_mat[-(n_rows - p), , drop = FALSE]
  }

  term_names <- rownames(coef_mat)
  # Temporary column names; they only serve as the pivot index and are
  # discarded below.
  colnames(coef_mat) <- as.character(seq_along(penalty))

  wide <- tibble::as_tibble(coef_mat)
  wide$term <- term_names

  long <- tidyr::pivot_longer(
    wide,
    cols = -term,
    names_to = "index",
    values_to = "estimate"
  )
  # After pivoting, the rows run through every penalty for the first term,
  # then every penalty for the second term, and so on.
  long$penalty <- rep(penalty, p)
  long$index <- NULL
  long
}
| 76 | + |
# Select the terms with the largest coefficients in absolute value.
#
# `x` is a data frame with `term` and `estimate` columns. Each term is first
# reduced to the single row where its absolute estimate is largest. The terms
# are then ranked by that value and the first `top_n` rows are returned
# (zero rows when `top_n` is 0).
top_coefs <- function(x, top_n = 5) {
  x %>%
    dplyr::group_by(term) %>%
    dplyr::arrange(term, dplyr::desc(abs(estimate))) %>%
    dplyr::slice(1) %>%
    dplyr::ungroup() %>%
    dplyr::arrange(dplyr::desc(abs(estimate))) %>%
    # `seq_len()` rather than `1:top_n`: `1:0` is `c(1, 0)`, which would keep
    # one row even though none were requested.
    dplyr::slice(seq_len(top_n))
}
| 86 | + |
# Build the coefficient-path plot for a glmnet fit.
#
# Arguments:
#   x             The glmnet fit; passed to `map_glmnet_coefs()`.
#   min_penalty   Single non-negative number; smaller penalties are dropped.
#   best_penalty  `NULL`, or a single non-negative number. When given, it is
#                 marked with a dotted vertical line and used as the
#                 x-position of the labels.
#   top_n         Number of terms to label; clamped to lie between 0 and the
#                 number of terms.
#   ...           Passed on to `ggrepel::geom_label_repel()`.
#
# Returns a ggplot object with penalty (log10 scale) on the x-axis and the
# coefficient estimates on the y-axis, one line per term. The plot is faceted
# by `class` when the tidied coefficients contain that column.
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL,
                            top_n = 3L, ...) {
  # Validate both penalties up front, before either is used in a filter or a
  # mutate, so that a bad value fails with the intended message.
  check_penalty_value(min_penalty)
  if (!is.null(best_penalty)) {
    check_penalty_value(best_penalty)
  }

  tidy_coefs <-
    map_glmnet_coefs(x) %>%
    dplyr::filter(penalty >= min_penalty)

  actual_min_penalty <- min(tidy_coefs$penalty)
  num_terms <- length(unique(tidy_coefs$term))
  top_n <- min(top_n[1], num_terms)
  if (top_n < 0) {
    top_n <- 0
  }

  has_groups <- any(names(tidy_coefs) == "class")

  # Keep the large values
  if (has_groups) {
    # Pick the top terms separately within each class
    label_coefs <-
      tidy_coefs %>%
      dplyr::group_nest(class) %>%
      dplyr::mutate(data = purrr::map(data, top_coefs, top_n = top_n)) %>%
      dplyr::select(class, data) %>%
      tidyr::unnest(cols = data)
  } else {
    if (is.null(best_penalty)) {
      label_coefs <- tidy_coefs %>%
        top_coefs(top_n)
    } else {
      # Rank the terms at the smallest penalty on the path that is strictly
      # larger than `best_penalty`.
      label_coefs <- tidy_coefs %>%
        dplyr::filter(penalty > best_penalty) %>%
        dplyr::filter(penalty == min(penalty)) %>%
        dplyr::arrange(dplyr::desc(abs(estimate))) %>%
        dplyr::slice(seq_len(top_n))
    }
  }

  # Labels are drawn at `best_penalty` when supplied, otherwise at the
  # smallest penalty shown.
  label_coefs <-
    label_coefs %>%
    dplyr::mutate(penalty = best_penalty %||% actual_min_penalty) %>%
    dplyr::mutate(label = gsub(".pred_no_", "", term))

  # plot the paths and highlight the large values
  p <-
    tidy_coefs %>%
    ggplot2::ggplot(
      ggplot2::aes(x = penalty, y = estimate, group = term, col = term)
    )

  if (has_groups) {
    p <- p + ggplot2::facet_wrap(~ class)
  }

  if (!is.null(best_penalty)) {
    p <- p + ggplot2::geom_vline(xintercept = best_penalty, lty = 3)
  }

  p <- p +
    ggplot2::geom_line(alpha = .4, show.legend = FALSE) +
    ggplot2::scale_x_log10()

  if (top_n > 0) {
    # ggrepel is only needed when labels are actually drawn
    rlang::check_installed("ggrepel")
    p <- p +
      ggrepel::geom_label_repel(
        data = label_coefs,
        ggplot2::aes(y = estimate, label = label),
        show.legend = FALSE,
        ...
      )
  }
  p
}
| 159 | + |
# Validate a penalty argument.
#
# `x` must be a single, non-missing, non-negative number. Anything else
# (including `NULL`, `NA`, vectors of length other than one, and non-numeric
# values) aborts with a message that names the argument as the caller wrote
# it. On success, `x` is returned invisibly.
check_penalty_value <- function(x) {
  # Capture the caller's expression for the error message. `deparse()` keeps
  # a non-symbol expression as one string instead of splitting it into pieces.
  arg_val <- paste(deparse(substitute(x)), collapse = "")

  # `!is.na(x)` guards the comparison below: `NA >= 0` is `NA`, which would
  # otherwise make `if` fail with an unrelated error.
  is_valid <-
    is.vector(x) && length(x) == 1 && is.numeric(x) && !is.na(x) && x >= 0

  if (!is_valid) {
    msg <- paste0(
      "Argument '", arg_val, "' should be a single, non-negative value."
    )
    rlang::abort(msg)
  }
  invisible(x)
}
| 169 | + |
| 170 | +# nocov end |
0 commit comments