|
| 1 | +#' Repair a model call object |
| 2 | +#' |
| 3 | +#' When the user passes a formula to `fit()` _and_ the underyling model function |
| 4 | +#' uses a formula, the call object produced by `fit()` may not be usable by |
| 5 | +#' other functions. For example, some arguments may still be quosures and the |
| 6 | +#' `data` portion of the call will not correspond to the original data. |
| 7 | +#' |
| 8 | +#' `repair_call()` call can adjust the model objects call to be usable by other |
| 9 | +#' functions and methods. |
| 10 | +#' @param x A fitted `parsnip` model. An error will occur if the underlying model |
| 11 | +#' does not have a `call` element. |
| 12 | +#' @param data A data object that is relavant to the call. In most cases, this |
| 13 | +#' is the data frame that was given to `parsnip` for the model fit (i.e., the |
| 14 | +#' training set data). The name of this data object is inserted into the call. |
| 15 | +#' @return A modified `parsnip` fitted model. |
| 16 | +#' @examples |
| 17 | +#' |
| 18 | +#' fitted_model <- |
| 19 | +#' linear_reg() %>% |
| 20 | +#' set_engine("lm", model = TRUE) %>% |
| 21 | +#' fit(mpg ~ ., data = mtcars) |
| 22 | +#' |
| 23 | +#' # In this call, note that `data` is not `mtcars` and the `model = ~TRUE` |
| 24 | +#' # indicates that the `model` argument is an `rlang` quosure. |
| 25 | +#' fitted_model$fit$call |
| 26 | +#' |
| 27 | +#' # All better: |
| 28 | +#' repair_call(fitted_model, mtcars)$fit$call |
| 29 | +#' @export |
| 30 | +repair_call <- function(x, data) { |
| 31 | + cl <- match.call() |
| 32 | + if (!any(names(x$fit) == "call")) { |
| 33 | + rlang::abort("No `call` object to modify.") |
| 34 | + } |
| 35 | + if (rlang::is_missing(data)) { |
| 36 | + rlang::abort("Please supply a data object to `data`.") |
| 37 | + } |
| 38 | + fit_call <- x$fit$call |
| 39 | + needs_eval <- purrr::map_lgl(fit_call, rlang::is_quosure) |
| 40 | + if (any(needs_eval)) { |
| 41 | + eval_args <- names(needs_eval)[needs_eval] |
| 42 | + for(arg in eval_args) { |
| 43 | + fit_call[[arg]] <- rlang::eval_tidy(fit_call[[arg]]) |
| 44 | + } |
| 45 | + } |
| 46 | + if (any(names(fit_call) == "data")) { |
| 47 | + fit_call$data <- cl$data |
| 48 | + } |
| 49 | + |
| 50 | + x$fit$call <- fit_call |
| 51 | + x |
| 52 | +} |
| 53 | + |
| 54 | +mod2_fit$fit$call |
| 55 | +repair_call(mod2_fit, mtcars)$fit$call |
| 56 | + |
| 57 | + |
0 commit comments