Skip to content

Commit 22192e6

Browse files
committed
fit recipe interfaces plus some documentation
1 parent 34c0323 commit 22192e6

File tree

3 files changed

+134
-37
lines changed

3 files changed

+134
-37
lines changed

NAMESPACE

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ export(fit)
1010
export(rand_forest)
1111
export(varying)
1212
importFrom(purrr,map_lgl)
13+
importFrom(recipes,all_outcomes)
14+
importFrom(recipes,all_predictors)
15+
importFrom(recipes,juice)
16+
importFrom(recipes,prep)
1317
importFrom(rlang,enquo)
1418
importFrom(rlang,eval_tidy)
1519
importFrom(rlang,expr)
@@ -24,5 +28,5 @@ importFrom(rlang,na_lgl)
2428
importFrom(rlang,quos)
2529
importFrom(stats,as.formula)
2630
importFrom(stats,model.frame)
27-
importFrom(utils,capture.output)
31+
importFrom(stats,model.response)
2832
importFrom(utils,installed.packages)

R/fit.R

Lines changed: 93 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# - think about case weights in each instance below
33
# - try/catch all model fit evaluations
44
# - option to capture output/verboseness
5-
5+
# - devise a unit test plan that does not add pkg deps for each model
66

77

88
#' Fit a Model Specification to a Dataset
@@ -12,32 +12,41 @@
1212
#' routine.
1313
#'
1414
#' @param object An object of class `model_spec`
15+
#' @param x Either an R formula, a data frame of predictors, or a
16+
#' recipe object.
17+
#' @param engine A character string for the software that should
18+
#' be used to fit the model. This is highly dependent on the type
19+
#' of model (e.g. linear regression, random forest, etc.).
20+
#' @param ... Other options required to fit the model. If `x` is a
21+
#' formula or recipe, then the `data` argument should be passed
22+
#' here. For the "x/y" interface, the outcome data should be passed
23+
#' in with the argument `y`.
24+
#' @details `fit` substitutes the current arguments in the model
25+
#' specification into the computational engine's code, checks them
26+
#' for validity, then fits the model using the data and the
27+
#' engine-specific code. Different model functions have different
28+
#' interfaces (e.g. formula or `x`/`y`) and `fit` translates
29+
#' between the interface used when `fit` was invoked and the one
30+
#' required by the underlying model.
31+
#'
32+
#' When possible, `fit` attempts to avoid making copies of the
33+
#' data. For example, if the underlying model uses a formula and
34+
#' `fit` is invoked with a formula, the original data are referenced
35+
#' when the model is fit. However, if the underlying model uses
36+
#' something else, such as `x`/`y`, the formula is evaluated and
37+
#' the data are converted to the required format. In this case, any
38+
#' calls in the resulting model objects reference the temporary
39+
#' objects used to fit the model.
1540
#' @export
1641
#' @rdname fit
1742
fit <- function (object, ...)
1843
UseMethod("fit")
1944

2045
# The S3 part here is awful (for now I hope)
2146

22-
#' @importFrom utils capture.output
23-
# fit_formula <- function(object, formula, data, verboseness = 0, engine = "ranger") {
24-
# varying_param_check(object)
25-
#
26-
# # go between input methods
27-
#
28-
# # data checks based on method
29-
#
30-
# object <- finalize(object, engine = engine)
31-
# if(verboseness == 0) {
32-
# fit_obj <- eval(object$method$fit)
33-
# } else {
34-
# capture.output(fit_obj <- eval(object$method$fit))
35-
# }
36-
# fit_obj
37-
# }
38-
39-
47+
#' @return An object for the fitted model.
4048
#' @export
49+
#' @rdname fit
4150
fit.model_spec <- function(object, x, engine = object$engine, ...) {
4251
object$engine <- engine
4352
object <- check_engine(object)
@@ -68,21 +77,21 @@ fit.model_spec <- function(object, x, engine = object$engine, ...) {
6877
#' @importFrom stats as.formula
6978
fit_formula <- function(object, formula, engine = engine, ...) {
7079
opts <- quos(...)
71-
80+
7281
if(!any(names(opts) == "data"))
7382
stop("Please pass a data frame with the `data` argument.",
7483
call. = FALSE)
7584

7685
# TODO Should probably just load the namespace
7786
for(pkg in object$method$library)
7887
suppressPackageStartupMessages(library(pkg, character.only = TRUE))
79-
88+
8089
# Look up the model's interface (e.g. formula, recipes, etc)
8190
# and delegate to the connector functions (`formula_to_recipe` etc)
8291
if(object$method$interface == "formula") {
8392
fit_expr <- sub_arg_values(object$method$fit, opts["data"])
8493
fit_expr$formula <- as.formula(eval(formula))
85-
res <- rlang:::eval_tidy(fit_expr)
94+
res <- eval_tidy(fit_expr)
8695
} else {
8796
if(object$method$interface %in% c("data.frame", "matrix")) {
8897
res <- formula_to_xy(object = object, formula = formula, data = opts["data"])
@@ -96,8 +105,6 @@ fit_formula <- function(object, formula, engine = engine, ...) {
96105
}
97106

98107
fit_xy <- function(object, x, ...) {
99-
# Look up the model's interface (e.g. formula, recipes, etc)
100-
# and delagate to the connector functions (`xy_to_formula` etc)
101108
opts <- quos(...)
102109

103110
if(!any(names(opts) == "y"))
@@ -109,13 +116,13 @@ fit_xy <- function(object, x, ...) {
109116
suppressPackageStartupMessages(library(pkg, character.only = TRUE))
110117

111118
# Look up the model's interface (e.g. formula, recipes, etc)
112-
# and delagate to the connector functions (`xy_to_formula` etc)
119+
# and delegate to the connector functions (`xy_to_formula` etc)
113120
if(object$method$interface == "formula") {
114121
res <- xy_to_formula(object = object, x = x, y = opts["y"])
115122
} else {
116123
if(object$method$interface %in% c("data.frame", "matrix")) {
117124
fit_expr <- sub_arg_values(object$method$fit, opts["y"])
118-
res <- rlang:::eval_tidy(fit_expr)
125+
res <- eval_tidy(fit_expr)
119126
} else {
120127
stop("I don't know about the ",
121128
object$method$interface, " interface.",
@@ -126,8 +133,31 @@ fit_xy <- function(object, x, ...) {
126133
}
127134

128135
fit_recipe <- function(object, recipe, ...) {
136+
opts <- quos(...)
137+
138+
if(!any(names(opts) == "data"))
139+
stop("Please pass a data frame with the `data` argument.",
140+
call. = FALSE)
141+
142+
# TODO Should probably just load the namespace
143+
for(pkg in object$method$library)
144+
suppressPackageStartupMessages(library(pkg, character.only = TRUE))
145+
129146
# Look up the model's interface (e.g. formula, recipes, etc)
130-
# and delagate to the connector functions (`recipe_to_formula` etc)
147+
# and delegate to the connector functions (`recipe_to_formula` etc)
148+
if(object$method$interface == "formula") {
149+
res <- recipe_to_formula(object = object, recipe = recipe, data = opts["data"])
150+
} else {
151+
if(object$method$interface %in% c("data.frame", "matrix")) {
152+
res <- recipe_to_xy(object = object, recipe = recipe, data = opts["data"])
153+
} else {
154+
stop("I don't know about the ",
155+
object$method$interface, " interface.",
156+
call. = FALSE)
157+
}
158+
}
159+
res
160+
131161
}
132162

133163
###################################################################
@@ -139,25 +169,56 @@ formula_to_recipe <- function(object, formula, data) {
139169

140170
}
141171

142-
#' @importFrom stats model.frame
172+
#' @importFrom stats model.frame model.response
143173
formula_to_xy <- function(object, formula, data) {
144174
# TODO how do we fill in the other standard things here (subset, contrasts etc)?
145175
# TODO add a "matrix" option here and invoke model.matrix
146176

147-
# Q: avoid eval using ?rlang:::get_expr(data[["data"]])
177+
# Q: avoid eval using ?get_expr(data[["data"]])
148178
x <- stats::model.frame(formula, eval_tidy(data[["data"]]))
149179
y <- model.response(x, "numeric")
150180
eval_tidy(object$method$fit)
151181
}
152182

153183
###################################################################
154184

185+
#' @importFrom recipes prep juice all_predictors all_outcomes
155186
recipe_to_formula <- function(object, recipe, data) {
156-
187+
# TODO case weights
188+
recipe <-
189+
prep(recipe, training = eval_tidy(data[["data"]]), retain = TRUE)
190+
dat <- juice(recipe, all_predictors(), all_outcomes())
191+
dat <- as.data.frame(dat)
192+
193+
data_info <- summary(recipe)
194+
y_names <- data_info$variable[data_info$role == "outcome"]
195+
if (length(y_names) > 1)
196+
y_names <-
197+
paste0("cbind(", paste0(y_names, collapse = ","), ")")
198+
199+
fit_expr <- object$method$fit
200+
fit_expr$formula <- as.formula(paste0(y_names, "~."))
201+
fit_expr$data <- quote(dat)
202+
eval_tidy(fit_expr)
157203
}
158204

159205
recipe_to_xy <- function(object, recipe, data) {
160-
206+
# TODO case weights
207+
recipe <-
208+
prep(recipe, training = eval_tidy(data[["data"]]), retain = TRUE)
209+
210+
x <- juice(recipe, all_predictors())
211+
x <- as.data.frame(x)
212+
y <- juice(recipe, all_outcomes())
213+
if (ncol(y) > 1)
214+
y <- as.data.frame(y)
215+
else
216+
y <- y[[1]]
217+
218+
fit_expr <- object$method$fit
219+
fit_expr$x <- quote(x)
220+
fit_expr$y <- quote(y)
221+
eval_tidy(fit_expr)
161222
}
162223

163224
###################################################################
@@ -169,13 +230,9 @@ xy_to_formula <- function(object, x, y) {
169230
fit_expr <- object$method$fit
170231
fit_expr$formula <- as.formula(.y ~ .)
171232
fit_expr$data <- quote(x)
172-
rlang:::eval_tidy(fit_expr)
233+
eval_tidy(fit_expr)
173234
}
174235

175236
xy_to_recipe <- function(object, x, y) {
176237

177238
}
178-
179-
180-
181-

man/fit.Rd

Lines changed: 36 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)