Skip to content

Commit 1f07131

Browse files
committed
more error traps for spark
1 parent ff60618 commit 1f07131

File tree

1 file changed

+55
-41
lines changed

R/fitter.R

Lines changed: 55 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1,28 +1,36 @@
11
# The "fit_interface" is what was supplied to `fit` as defined by
22
# `check_interface`. The "model interface" is what the underlying
3-
# model uses. These functions go from one to another.
3+
# model uses. These functions go from one to another.
44

55
# TODO return pp objects like terms or recipe
66

77
# TODO protect engine = "spark" with non-spark data object
88

99
fit_interface_matrix <- function(x, y, object, control, ...) {
10+
if (object$engine == "spark")
11+
stop("spark objects can only be used with the formula interface to `fit` ",
12+
"with a spark data object.", call. = FALSE)
1013
switch(
1114
object$method$interface,
1215
data.frame = matrix_to_data.frame(object, x, y, control, ...),
1316
matrix = matrix_to_matrix(object, x, y, control, ...),
1417
formula = matrix_to_formula(object, x, y, control, ...),
15-
stop("I don't know about that model interface.", call. = FALSE)
18+
stop("I don't know about model interface '",
19+
object$method$interface, "'.", call. = FALSE)
1620
)
1721
}
1822

1923
fit_interface_data.frame <- function(x, y, object, control, ...) {
24+
if (object$engine == "spark")
25+
stop("spark objects can only be used with the formula interface to `fit` ",
26+
"with a spark data object.", call. = FALSE)
2027
switch(
2128
object$method$interface,
2229
data.frame = data.frame_to_data.frame(object, x, y, control, ...),
2330
matrix = data.frame_to_matrix(object, x, y, control, ...),
2431
formula = data.frame_to_formula(object, x, y, control, ...),
25-
stop("I don't know about that model interface.", call. = FALSE)
32+
stop("I don't know about model interface '",
33+
object$method$interface, "'.", call. = FALSE)
2634
)
2735
}
2836

@@ -32,25 +40,27 @@ fit_interface_formula <- function(formula, data, object, control, ...) {
3240
data.frame = formula_to_data.frame(object, formula, data, control, ...),
3341
matrix = formula_to_matrix(object, formula, data, control, ...),
3442
formula = formula_to_formula(object, formula, data, control, ...),
35-
stop("I don't know about that model interface.", call. = FALSE)
43+
stop("I don't know about model interface '",
44+
object$method$interface, "'.", call. = FALSE)
3645
)
3746
}
3847

3948
fit_interface_recipe <- function(recipe, data, object, control, ...) {
40-
if (inherits(datax, "tbl_spark"))
41-
stop("spark objects can only be used with the formula interface to `fit`",
42-
call. = FALSE)
49+
if (object$engine == "spark")
50+
stop("spark objects can only be used with the formula interface to `fit` ",
51+
"with a spark data object.", call. = FALSE)
4352
switch(
4453
object$method$interface,
45-
data.frame = I(),
46-
formula = I(),
47-
matrix = I(),
48-
stop("I don't know about that model interface.", call. = FALSE)
54+
data.frame = recipe_to_data.frame(object, recipe, data, control, ...),
55+
formula = recipe_to_formula(object, recipe, data, control, ...),
56+
matrix = recipe_to_matrix(object, recipe, data, control, ...),
57+
stop("I don't know about model interface '",
58+
object$method$interface, "'.", call. = FALSE)
4959
)
5060
}
5161

5262
###################################################################
53-
## starts with some x/y interface (either matrix or data frame)
63+
## starts with some x/y interface (either matrix or data frame)
5464
## in `fit`
5565

5666
#' @importFrom dplyr bind_cols
@@ -60,16 +70,16 @@ xy_to_xy <- function(object, x, y, control, ...) {
6070
if (inherits(x, "tbl_spark") | inherits(y, "tbl_spark"))
6171
stop("spark objects can only be used with the formula interface to `fit`",
6272
call. = FALSE)
63-
73+
6474
object$method$fit_args[["x"]] <- quote(x)
6575
object$method$fit_args[["y"]] <- quote(y)
66-
76+
6777
fit_call <- make_call(
6878
fun = object$method$fit_name["fun"],
6979
ns = object$method$fit_name["pkg"],
7080
object$method$fit_args
7181
)
72-
82+
7383
eval_mod(
7484
fit_call,
7585
capture = control$verbosity == 0,
@@ -132,27 +142,25 @@ data.frame_to_formula <- function(object, x, y, control, ...) {
132142
###################################################################
133143
## Start with formula interface in `fit`
134144

135-
#' @importFrom stats model.frame model.response terms as.formula
145+
#' @importFrom stats model.frame model.response terms as.formula model.matrix
136146

137147
formula_to_formula <-
138148
function(object, formula, data, control, ...) {
139149
opts <- quos(...)
140-
150+
141151
fit_args <- object$method$fit_args
142-
# handle unevaluated arguments
143-
fit_args <- resolve_args(fit_args, env = current_env())
144-
145-
if (!inherits(data, "tbl_spark")) {
146-
fit_args$data <- data
147-
} else {
152+
153+
if (isTRUE(unname(object$method$fit_name["pkg"] == "sparklyr"))) {
148154
fit_args$x <- data
155+
} else {
156+
fit_args$data <- data
149157
}
150158
fit_args$formula <- formula
151-
159+
152160
fit_call <- make_call(fun = object$method$fit_name["fun"],
153161
ns = object$method$fit_name["pkg"],
154162
fit_args)
155-
163+
156164
res <-
157165
eval_mod(
158166
fit_call,
@@ -165,14 +173,17 @@ formula_to_formula <-
165173
}
166174

167175
formula_to_data.frame <- function(object, formula, data, control, ...) {
176+
if (is.name(data))
177+
data <- eval_tidy(data, env = caller_env())
178+
168179
if (!is.data.frame(data))
169180
data = as.data.frame(data)
170-
181+
171182
# TODO: how do we fill in the other standard things here (subset, contrasts etc)?
172-
183+
173184
x <- stats::model.frame(eval(formula), eval(data))
174185
y <- model.response(x)
175-
186+
176187
# Remove outcome column(s) from `x`
177188
outcome_cols <- attr(terms(x), "response")
178189
if (!isTRUE(all.equal(outcome_cols, 0))) {
@@ -182,22 +193,25 @@ formula_to_data.frame <- function(object, formula, data, control, ...) {
182193
}
183194

184195
formula_to_matrix <- function(object, formula, data, control, ...) {
196+
if (is.name(data))
197+
data <- eval_tidy(data, env = caller_env())
198+
185199
if (!is.data.frame(data))
186200
data = as.data.frame(data)
187-
201+
188202
# TODO: how do we fill in the other standard things here (subset, etc)?
189-
203+
190204
x <- stats::model.frame(eval(formula), eval(data))
191205
trms <- attr(x, "terms")
192206
y <- model.response(x)
193207
if (is.data.frame(y))
194208
y <- as.matrix(y)
195-
209+
196210
# TODO sparse model matrices?
197211
x <- model.matrix(trms, data = x, contrasts.arg = getOption("contrasts"))
198212
# TODO Assume no intercept for now
199-
x <- x[, !(colnames(x) %in% "(Intercept)"), dtop = FALSE]
200-
213+
x <- x[, !(colnames(x) %in% "(Intercept)"), drop = FALSE]
214+
201215
xy_to_xy(object, x, y, control, ...)
202216
}
203217

@@ -209,7 +223,7 @@ formula_to_matrix <- function(object, formula, data, control, ...) {
209223
recipe_data <- function(recipe, data, control, output = "matrix", combine = FALSE) {
210224
recipe <-
211225
prep(recipe, training = data, retain = TRUE, verbose = control$verbosity > 1)
212-
226+
213227
if (combine) {
214228
out <- list(data = juice(recipe, all_predictors(), all_outcomes(), composition = output))
215229
data_info <- summary(recipe)
@@ -225,23 +239,23 @@ recipe_data <- function(recipe, data, control, output = "matrix", combine = FALS
225239
y = juice(recipe, all_outcomes(), composition = output)
226240
)
227241
if (ncol(out$y) == 1)
228-
y <- y[[1]]
242+
out$y <- out$y[[1]]
229243
}
230244
out
231245
}
232246

233247
recipe_to_formula <-
234248
function(object, recipe, data, control, ...) {
235-
info <- recipe_data(recipe, data, control, output = "data.frame", combine = TRUE)
236-
formula_to_formula(object, dat$form, dat$data, control, ...)
249+
info <- recipe_data(recipe, data, control, output = "tibble", combine = TRUE)
250+
formula_to_formula(object, info$form, info$data, control, ...)
237251
}
238252

239253
recipe_to_data.frame <- function(object, recipe, data, control, ...) {
240-
dat <- recipe_data(recipe, data, control, output = "data.frame", combine = FALSE)
241-
xy_to_xy(object, dat$x, dat$y, control, ...)
254+
info <- recipe_data(recipe, data, control, output = "tibble", combine = FALSE)
255+
xy_to_xy(object, info$x, info$y, control, ...)
242256
}
243257

244258
recipe_to_matrix <- function(object, recipe, data, control, ...) {
245-
dat <- recipe_data(recipe, data, control, output = "matrix", combine = FALSE)
246-
xy_to_xy(object, dat$x, dat$y, control, ...)
259+
info <- recipe_data(recipe, data, control, output = "matrix", combine = FALSE)
260+
xy_to_xy(object, info$x, info$y, control, ...)
247261
}

0 commit comments

Comments
 (0)