Skip to content

Commit fa59be1

Browse files
committed
This function along with others changes would solve #274
1 parent 3c95d2e commit fa59be1

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

R/repair_call.R

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#' Repair a model call object
2+
#'
3+
#' When the user passes a formula to `fit()` _and_ the underyling model function
4+
#' uses a formula, the call object produced by `fit()` may not be usable by
5+
#' other functions. For example, some arguments may still be quosures and the
6+
#' `data` portion of the call will not correspond to the original data.
7+
#'
8+
#' `repair_call()` call can adjust the model objects call to be usable by other
9+
#' functions and methods.
10+
#' @param x A fitted `parsnip` model. An error will occur if the underlying model
11+
#' does not have a `call` element.
12+
#' @param data A data object that is relavant to the call. In most cases, this
13+
#' is the data frame that was given to `parsnip` for the model fit (i.e., the
14+
#' training set data). The name of this data object is inserted into the call.
15+
#' @return A modified `parsnip` fitted model.
16+
#' @examples
17+
#'
18+
#' fitted_model <-
19+
#' linear_reg() %>%
20+
#' set_engine("lm", model = TRUE) %>%
21+
#' fit(mpg ~ ., data = mtcars)
22+
#'
23+
#' # In this call, note that `data` is not `mtcars` and the `model = ~TRUE`
24+
#' # indicates that the `model` argument is an `rlang` quosure.
25+
#' fitted_model$fit$call
26+
#'
27+
#' # All better:
28+
#' repair_call(fitted_model, mtcars)$fit$call
29+
#' @export
30+
repair_call <- function(x, data) {
31+
cl <- match.call()
32+
if (!any(names(x$fit) == "call")) {
33+
rlang::abort("No `call` object to modify.")
34+
}
35+
if (rlang::is_missing(data)) {
36+
rlang::abort("Please supply a data object to `data`.")
37+
}
38+
fit_call <- x$fit$call
39+
needs_eval <- purrr::map_lgl(fit_call, rlang::is_quosure)
40+
if (any(needs_eval)) {
41+
eval_args <- names(needs_eval)[needs_eval]
42+
for(arg in eval_args) {
43+
fit_call[[arg]] <- rlang::eval_tidy(fit_call[[arg]])
44+
}
45+
}
46+
if (any(names(fit_call) == "data")) {
47+
fit_call$data <- cl$data
48+
}
49+
50+
x$fit$call <- fit_call
51+
x
52+
}
53+
54+
mod2_fit$fit$call
55+
repair_call(mod2_fit, mtcars)$fit$call
56+
57+

0 commit comments

Comments
 (0)