11# The "fit_interface" is what was supplied to `fit` as defined by
22# `check_interface`. The "model interface" is what the underlying
3- # model uses. These functions go from one to another.
3+ # model uses. These functions go from one to another.
44
55# TODO return pp objects like terms or recipe
66
77# TODO protect engine = "spark" with non-spark data object
88
99fit_interface_matrix <- function (x , y , object , control , ... ) {
10+ if (object $ engine == " spark" )
11+ stop(" spark objects can only be used with the formula interface to `fit` " ,
12+ " with a spark data object." , call. = FALSE )
1013 switch (
1114 object $ method $ interface ,
1215 data.frame = matrix_to_data.frame(object , x , y , control , ... ),
1316 matrix = matrix_to_matrix(object , x , y , control , ... ),
1417 formula = matrix_to_formula(object , x , y , control , ... ),
15- stop(" I don't know about that model interface." , call. = FALSE )
18+ stop(" I don't know about model interface '" ,
19+ object $ method $ interface , " '." , call. = FALSE )
1620 )
1721}
1822
1923fit_interface_data.frame <- function (x , y , object , control , ... ) {
24+ if (object $ engine == " spark" )
25+ stop(" spark objects can only be used with the formula interface to `fit` " ,
26+ " with a spark data object." , call. = FALSE )
2027 switch (
2128 object $ method $ interface ,
2229 data.frame = data.frame_to_data.frame(object , x , y , control , ... ),
2330 matrix = data.frame_to_matrix(object , x , y , control , ... ),
2431 formula = data.frame_to_formula(object , x , y , control , ... ),
25- stop(" I don't know about that model interface." , call. = FALSE )
32+ stop(" I don't know about model interface '" ,
33+ object $ method $ interface , " '." , call. = FALSE )
2634 )
2735}
2836
@@ -32,25 +40,27 @@ fit_interface_formula <- function(formula, data, object, control, ...) {
3240 data.frame = formula_to_data.frame(object , formula , data , control , ... ),
3341 matrix = formula_to_matrix(object , formula , data , control , ... ),
3442 formula = formula_to_formula(object , formula , data , control , ... ),
35- stop(" I don't know about that model interface." , call. = FALSE )
43+ stop(" I don't know about model interface '" ,
44+ object $ method $ interface , " '." , call. = FALSE )
3645 )
3746}
3847
3948fit_interface_recipe <- function (recipe , data , object , control , ... ) {
40- if (inherits( datax , " tbl_spark " ) )
41- stop(" spark objects can only be used with the formula interface to `fit`" ,
42- call. = FALSE )
49+ if (object $ engine == " spark " )
50+ stop(" spark objects can only be used with the formula interface to `fit` " ,
51+ " with a spark data object. " , call. = FALSE )
4352 switch (
4453 object $ method $ interface ,
45- data.frame = I(),
46- formula = I(),
47- matrix = I(),
48- stop(" I don't know about that model interface." , call. = FALSE )
54+ data.frame = recipe_to_data.frame(object , recipe , data , control , ... ),
55+ formula = recipe_to_formula(object , recipe , data , control , ... ),
56+ matrix = recipe_to_matrix(object , recipe , data , control , ... ),
57+ stop(" I don't know about model interface '" ,
58+ object $ method $ interface , " '." , call. = FALSE )
4959 )
5060}
5161
5262# ##################################################################
53- # # starts with some x/y interface (either matrix or data frame)
63+ # # starts with some x/y interface (either matrix or data frame)
5464# # in `fit`
5565
5666# ' @importFrom dplyr bind_cols
@@ -60,16 +70,16 @@ xy_to_xy <- function(object, x, y, control, ...) {
6070 if (inherits(x , " tbl_spark" ) | inherits(y , " tbl_spark" ))
6171 stop(" spark objects can only be used with the formula interface to `fit`" ,
6272 call. = FALSE )
63-
73+
6474 object $ method $ fit_args [[" x" ]] <- quote(x )
6575 object $ method $ fit_args [[" y" ]] <- quote(y )
66-
76+
6777 fit_call <- make_call(
6878 fun = object $ method $ fit_name [" fun" ],
6979 ns = object $ method $ fit_name [" pkg" ],
7080 object $ method $ fit_args
7181 )
72-
82+
7383 eval_mod(
7484 fit_call ,
7585 capture = control $ verbosity == 0 ,
@@ -132,27 +142,25 @@ data.frame_to_formula <- function(object, x, y, control, ...) {
132142# ##################################################################
133143# # Start with formula interface in `fit`
134144
135- # ' @importFrom stats model.frame model.response terms as.formula
145+ # ' @importFrom stats model.frame model.response terms as.formula model.matrix
136146
137147formula_to_formula <-
138148 function (object , formula , data , control , ... ) {
139149 opts <- quos(... )
140-
150+
141151 fit_args <- object $ method $ fit_args
142- # handle unevaluated arguments
143- fit_args <- resolve_args(fit_args , env = current_env())
144-
145- if (! inherits(data , " tbl_spark" )) {
146- fit_args $ data <- data
147- } else {
152+
153+ if (isTRUE(unname(object $ method $ fit_name [" pkg" ] == " sparklyr" ))) {
148154 fit_args $ x <- data
155+ } else {
156+ fit_args $ data <- data
149157 }
150158 fit_args $ formula <- formula
151-
159+
152160 fit_call <- make_call(fun = object $ method $ fit_name [" fun" ],
153161 ns = object $ method $ fit_name [" pkg" ],
154162 fit_args )
155-
163+
156164 res <-
157165 eval_mod(
158166 fit_call ,
@@ -165,14 +173,17 @@ formula_to_formula <-
165173 }
166174
167175formula_to_data.frame <- function (object , formula , data , control , ... ) {
176+ if (is.name(data ))
177+ data <- eval_tidy(data , env = caller_env())
178+
168179 if (! is.data.frame(data ))
169180 data = as.data.frame(data )
170-
181+
171182 # TODO: how do we fill in the other standard things here (subset, contrasts etc)?
172-
183+
173184 x <- stats :: model.frame(eval(formula ), eval(data ))
174185 y <- model.response(x )
175-
186+
176187 # Remove outcome column(s) from `x`
177188 outcome_cols <- attr(terms(x ), " response" )
178189 if (! isTRUE(all.equal(outcome_cols , 0 ))) {
@@ -182,22 +193,25 @@ formula_to_data.frame <- function(object, formula, data, control, ...) {
182193}
183194
184195formula_to_matrix <- function (object , formula , data , control , ... ) {
196+ if (is.name(data ))
197+ data <- eval_tidy(data , env = caller_env())
198+
185199 if (! is.data.frame(data ))
186200 data = as.data.frame(data )
187-
201+
188202 # TODO: how do we fill in the other standard things here (subset, etc)?
189-
203+
190204 x <- stats :: model.frame(eval(formula ), eval(data ))
191205 trms <- attr(x , " terms" )
192206 y <- model.response(x )
193207 if (is.data.frame(y ))
194208 y <- as.matrix(y )
195-
209+
196210 # TODO sparse model matrices?
197211 x <- model.matrix(trms , data = x , contrasts.arg = getOption(" contrasts" ))
198212 # TODO Assume no intercept for now
199- x <- x [, ! (colnames(x ) %in% " (Intercept)" ), dtop = FALSE ]
200-
213+ x <- x [, ! (colnames(x ) %in% " (Intercept)" ), drop = FALSE ]
214+
201215 xy_to_xy(object , x , y , control , ... )
202216}
203217
@@ -209,7 +223,7 @@ formula_to_matrix <- function(object, formula, data, control, ...) {
209223recipe_data <- function (recipe , data , control , output = " matrix" , combine = FALSE ) {
210224 recipe <-
211225 prep(recipe , training = data , retain = TRUE , verbose = control $ verbosity > 1 )
212-
226+
213227 if (combine ) {
214228 out <- list (data = juice(recipe , all_predictors(), all_outcomes(), composition = output ))
215229 data_info <- summary(recipe )
@@ -225,23 +239,23 @@ recipe_data <- function(recipe, data, control, output = "matrix", combine = FALS
225239 y = juice(recipe , all_outcomes(), composition = output )
226240 )
227241 if (ncol(out $ y ) == 1 )
228- y <- y [[1 ]]
242+ out $ y <- out $ y [[1 ]]
229243 }
230244 out
231245}
232246
233247recipe_to_formula <-
234248 function (object , recipe , data , control , ... ) {
235- info <- recipe_data(recipe , data , control , output = " data.frame " , combine = TRUE )
236- formula_to_formula(object , dat $ form , dat $ data , control , ... )
249+ info <- recipe_data(recipe , data , control , output = " tibble " , combine = TRUE )
250+ formula_to_formula(object , info $ form , info $ data , control , ... )
237251 }
238252
239253recipe_to_data.frame <- function (object , recipe , data , control , ... ) {
240- dat <- recipe_data(recipe , data , control , output = " data.frame " , combine = FALSE )
241- xy_to_xy(object , dat $ x , dat $ y , control , ... )
254+ info <- recipe_data(recipe , data , control , output = " tibble " , combine = FALSE )
255+ xy_to_xy(object , info $ x , info $ y , control , ... )
242256}
243257
244258recipe_to_matrix <- function (object , recipe , data , control , ... ) {
245- dat <- recipe_data(recipe , data , control , output = " matrix" , combine = FALSE )
246- xy_to_xy(object , dat $ x , dat $ y , control , ... )
259+ info <- recipe_data(recipe , data , control , output = " matrix" , combine = FALSE )
260+ xy_to_xy(object , info $ x , info $ y , control , ... )
247261}
0 commit comments