Skip to content

Commit 301cf72

Browse files
authored
Make fit_xy() fail when mode is unknown (#801)
* Make fit_xy() fail when mode is unknown * remove check_mode() * revert Rproj option
1 parent f7999c0 commit 301cf72

File tree

5 files changed

+8
-16
lines changed

5 files changed

+8
-16
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
* Fixed installation failures due to undocumented knitr installation dependency (#785).
44

5+
* `fit_xy()` now fails when the model mode is unknown.
6+
57
# parsnip 1.0.1
68

79
* Enabled passing additional engine arguments with the xgboost `boost_tree()` engine. To supply engine-specific arguments that are documented in `xgboost::xgb.train()` as arguments to be passed via `params`, supply the list elements directly as named arguments to `set_engine()`. Read more in `?details_boost_tree_xgboost` (#787).

R/fit.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ fit_xy.model_spec <-
228228
control = control_parsnip(),
229229
...
230230
) {
231+
if (object$mode == "unknown") {
232+
rlang::abort("Please set the mode in the model specification.")
233+
}
234+
231235
if (object$mode == "censored regression") {
232236
rlang::abort("Models for censored regression must use the formula interface.")
233237
}
@@ -244,7 +248,6 @@ fit_xy.model_spec <-
244248
}
245249
check_case_weights(case_weights, object)
246250

247-
object <- check_mode(object, levels(y))
248251
dots <- quos(...)
249252
if (is.null(object$engine)) {
250253
eng_vals <- possible_engines(object)

R/fit_helpers.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ form_form <-
1212

1313
# prob rewrite this as simple subset/levels
1414
y_levels <- levels_from_formula(env$formula, env$data)
15-
object <- check_mode(object, y_levels)
1615

1716
# if descriptors are needed, update descr_env with the calculated values
1817
if (requires_descrs(object)) {
@@ -64,7 +63,6 @@ xy_xy <- function(object, env, control, target = "none", ...) {
6463
if (inherits(env$x, "tbl_spark") | inherits(env$y, "tbl_spark"))
6564
rlang::abort("spark objects can only be used with the formula interface to `fit()`")
6665

67-
object <- check_mode(object, levels(env$y))
6866
check_outcome(env$y, object)
6967

7068
encoding_info <-

R/translate.R

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,6 @@ translate.default <- function(x, engine = x$engine, ...) {
9797
x
9898
}
9999

100-
check_mode <- function(object, lvl) {
101-
if (object$mode == "unknown") {
102-
if (!is.null(lvl)) {
103-
object$mode <- "classification"
104-
} else {
105-
object$mode <- "regression"
106-
}
107-
}
108-
object
109-
}
110-
111100
# ------------------------------------------------------------------------------
112101
# new code for revised model data structures
113102

tests/testthat/test_fit_interfaces.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ test_that('unknown modes', {
6666
)
6767
expect_error(
6868
fit_xy(mars_spec, x = mtcars[, -1], y = mtcars[,1]),
69-
regexp = NA
69+
regexp = "Please set the mode in the model specification."
7070
)
7171
expect_error(
7272
fit_xy(mars_spec, x = lending_club[,1:2], y = lending_club$Class),
73-
regexp = NA
73+
regexp = "Please set the mode in the model specification."
7474
)
7575
})
7676

0 commit comments

Comments
 (0)