Skip to content

Commit 3a3c134

Browse files
committed
Merge IT ALL
Merge branch 'master' into encoding-options # Conflicts: # R/linear_reg_data.R # R/svm_poly_data.R # R/svm_rbf_data.R # tests/testthat/test_svm_poly.R # tests/testthat/test_svm_rbf.R
2 parents a420749 + bf507af commit 3a3c134

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+366
-870
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ export(predict_quantile.model_fit)
143143
export(predict_raw)
144144
export(predict_raw.model_fit)
145145
export(rand_forest)
146+
export(repair_call)
146147
export(rpart_train)
147148
export(set_args)
148149
export(set_dependency)

NEWS.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,20 @@
44

55
* `tidyr` >= 1.0.0 is now required.
66

7+
* SVM models produced by `kernlab` now use the formula method. This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksmv()` were called directly.
8+
9+
* MARS models produced by `earth` now use the formula method.
10+
11+
* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accomodated. (#315)
12+
713
## New Features
814

915
* A new main argument was added to `boost_tree()` called `stop_iter` for early stopping. The `xgb_train()` function gained arguments for early stopping and a percentage of data to leave out for a validation set.
1016

17+
* If `fit()` is used and the underlying model uses a formula, the _actual_ formula is pass to the model (instead of a placeholder). This makes the model call better.
18+
19+
* A function named `repair_call()` was added. This can help change the underlying models `call` object to better reflect what they would have obtained if the model function had been used directly (instead of via `parsnip`). This is only useful when the user chooses a formula interface and the model uses a formula interface. It will also be of limited use when a recipes is used to construct the feature set in `workflows` or `tune`.
20+
1121
# parsnip 0.1.1
1222

1323
## New Features

R/aaa.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ utils::globalVariables(
6666
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
6767
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
6868
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
69-
"sub_neighbors", ".pred_class")
69+
"sub_neighbors", ".pred_class", "x", "y")
7070
)
7171

7272
# nocov end

R/aaa_models.R

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,35 @@ check_fit_info <- function(fit_obj) {
195195
if (is.null(fit_obj)) {
196196
rlang::abort("The `fit` module cannot be NULL.")
197197
}
198+
199+
# check required data elements
198200
exp_nms <- c("defaults", "func", "interface", "protect")
199-
if (!isTRUE(all.equal(sort(names(fit_obj)), exp_nms))) {
201+
has_req_nms <- exp_nms %in% names(fit_obj)
202+
203+
if (!all(has_req_nms)) {
200204
rlang::abort(
201205
glue::glue("The `fit` module should have elements: ",
202206
glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", "))
203207
)
204208
}
205209

210+
# check optional data elements
211+
opt_nms <- c("data")
212+
other_nms <- setdiff(exp_nms, names(fit_obj))
213+
has_opt_nms <- other_nms %in% opt_nms
214+
if (any(!has_opt_nms)) {
215+
msg <- glue::glue("The `fit` module can only have optional elements: ",
216+
glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", "))
217+
218+
rlang::abort(msg)
219+
}
220+
if (any(other_nms == "data")) {
221+
data_nms <- names(fit_obj$data)
222+
if (length(data_nms == 0) || any(data_nms == "")) {
223+
rlang::abort("All elements of the `data` argument vector must be named.")
224+
}
225+
}
226+
206227
check_interface_val(fit_obj$interface)
207228
check_func_val(fit_obj$func)
208229

R/arguments.R

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ check_eng_args <- function(args, obj, core_args) {
2424
if (length(common_args) > 0) {
2525
args <- args[!(names(args) %in% common_args)]
2626
common_args <- paste0(common_args, collapse = ", ")
27-
rlang::warn(glue::glue("The following arguments cannot be manually modified",
27+
rlang::warn(glue::glue("The following arguments cannot be manually modified ",
2828
"and were removed: {common_args}."))
2929
}
3030
args
@@ -113,3 +113,94 @@ eval_args <- function(spec, ...) {
113113
spec$eng_args <- purrr::map(spec$eng_args, maybe_eval)
114114
spec
115115
}
116+
117+
# ------------------------------------------------------------------------------
118+
119+
# In some cases, a model function that we are calling has non-standard argument
120+
# names. For example, a function foo() that only has the x/y interface might
121+
# have a signature like `foo(X, Y)`.
122+
123+
# To deal with this, we allow for the `data` element of the model
124+
# as an option to specify these actual argument names
125+
#
126+
# value = list(
127+
# interface = "xy",
128+
# data = c(x = "X", y = "Y"),
129+
# protect = c("X", "Y"),
130+
# func = c(pkg = "bar", fun = "foo"),
131+
# defaults = list()
132+
# )
133+
134+
make_call <- function(fun, ns, args, ...) {
135+
# remove any null or placeholders (`missing_args`) that remain
136+
discard <-
137+
vapply(args, function(x)
138+
is_missing_arg(x) | is.null(x), logical(1))
139+
args <- args[!discard]
140+
141+
if (!is.null(ns) & !is.na(ns)) {
142+
out <- call2(fun, !!!args, .ns = ns)
143+
} else
144+
out <- call2(fun, !!!args)
145+
out
146+
}
147+
148+
149+
make_form_call <- function(object, env = NULL) {
150+
fit_args <- object$method$fit$args
151+
152+
# Get the arguments related to data:
153+
if (is.null(object$method$fit$data)) {
154+
data_args <- c(formula = "formula", data = "data")
155+
} else {
156+
data_args <- object$method$fit$data
157+
}
158+
159+
# add data arguments
160+
for (i in seq_along(data_args)) {
161+
fit_args[[ unname(data_args[i]) ]] <- sym(names(data_args)[i])
162+
}
163+
164+
# sub in actual formula
165+
fit_args[[ unname(data_args["formula"]) ]] <- env$formula
166+
167+
if (object$engine == "spark") {
168+
env$x <- env$data
169+
}
170+
171+
fit_call <- make_call(
172+
fun = object$method$fit$func["fun"],
173+
ns = object$method$fit$func["pkg"],
174+
fit_args
175+
)
176+
fit_call
177+
}
178+
179+
make_xy_call <- function(object, target) {
180+
fit_args <- object$method$fit$args
181+
182+
# Get the arguments related to data:
183+
if (is.null(object$method$fit$data)) {
184+
data_args <- c(x = "x", y = "y")
185+
} else {
186+
data_args <- object$method$fit$data
187+
}
188+
189+
object$method$fit$args[[ unname(data_args["y"]) ]] <- rlang::expr(y)
190+
object$method$fit$args[[ unname(data_args["x"]) ]] <-
191+
switch(
192+
target,
193+
none = rlang::expr(x),
194+
data.frame = rlang::expr(as.data.frame(x)),
195+
matrix = rlang::expr(as.matrix(x)),
196+
rlang::abort(glue::glue("Invalid data type target: {target}."))
197+
)
198+
199+
fit_call <- make_call(
200+
fun = object$method$fit$func["fun"],
201+
ns = object$method$fit$func["pkg"],
202+
object$method$fit$args
203+
)
204+
205+
fit_call
206+
}

R/boost_tree_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ set_fit(
358358
mode = "regression",
359359
value = list(
360360
interface = "formula",
361+
data = c(formula = "formula", data = "x"),
361362
protect = c("x", "formula", "type"),
362363
func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"),
363364
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
@@ -377,6 +378,7 @@ set_fit(
377378
mode = "classification",
378379
value = list(
379380
interface = "formula",
381+
data = c(formula = "formula", data = "x"),
380382
protect = c("x", "formula", "type"),
381383
func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"),
382384
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))

R/decision_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ print.decision_tree <- function(x, ...) {
102102

103103
#' @export
104104
#' @inheritParams update.boost_tree
105-
#' @param object A random forest model specification.
105+
#' @param object A decision tree model specification.
106106
#' @examples
107107
#' model <- decision_tree(cost_complexity = 10, min_n = 3)
108108
#' model

R/decision_tree_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ set_fit(
258258
mode = "regression",
259259
value = list(
260260
interface = "formula",
261+
data = c(formula = "formula", data = "x"),
261262
protect = c("x", "formula"),
262263
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
263264
defaults =
@@ -278,6 +279,7 @@ set_fit(
278279
mode = "classification",
279280
value = list(
280281
interface = "formula",
282+
data = c(formula = "formula", data = "x"),
281283
protect = c("x", "formula"),
282284
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
283285
defaults =

R/fit_helpers.R

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,7 @@ form_form <-
3030
# sub in arguments to actual syntax for corresponding engine
3131
object <- translate(object, engine = object$engine)
3232

33-
fit_args <- object$method$fit$args
34-
35-
if (is_spark(object)) {
36-
fit_args$x <- quote(x)
37-
env$x <- env$data
38-
} else {
39-
fit_args$data <- quote(data)
40-
}
41-
fit_args$formula <- quote(formula)
42-
43-
fit_call <- make_call(
44-
fun = object$method$fit$func["fun"],
45-
ns = object$method$fit$func["pkg"],
46-
fit_args
47-
)
33+
fit_call <- make_form_call(object, env = env)
4834

4935
res <- list(
5036
lvl = y_levels,
@@ -89,21 +75,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
8975
# sub in arguments to actual syntax for corresponding engine
9076
object <- translate(object, engine = object$engine)
9177

92-
object$method$fit$args[["y"]] <- quote(y)
93-
object$method$fit$args[["x"]] <-
94-
switch(
95-
target,
96-
none = quote(x),
97-
data.frame = quote(as.data.frame(x)),
98-
matrix = quote(as.matrix(x)),
99-
rlang::abort(glue::glue("Invalid data type target: {target}."))
100-
)
101-
102-
fit_call <- make_call(
103-
fun = object$method$fit$func["fun"],
104-
ns = object$method$fit$func["pkg"],
105-
object$method$fit$args
106-
)
78+
fit_call <- make_xy_call(object, target)
10779

10880
res <- list(lvl = levels(env$y), spec = object)
10981

R/linear_reg_data.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ set_fit(
314314
mode = "regression",
315315
value = list(
316316
interface = "formula",
317+
data = c(formula = "formula", data = "x"),
317318
protect = c("x", "formula", "weight_col"),
318319
func = c(pkg = "sparklyr", fun = "ml_linear_regression"),
319320
defaults = list()

0 commit comments

Comments
 (0)