Skip to content

Commit 91fb33e

Browse files
committed
better engine and mode checking code
1 parent f522f8f commit 91fb33e

File tree

7 files changed

+129
-43
lines changed

7 files changed

+129
-43
lines changed

NEWS.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
# parsnip (development version)
22

3+
## Model Specification Changes
4+
5+
* A model function (`gen_additive_mod()`) was added for generalized additive models.
6+
37
* Each model now has a default engine that is used when the model is defined. The default for each model is listed in the help documents. This also adds functionality to declare an engine in the model specification function. `set_engine()` is still required if engine-specific arguments need to be added. (#513)
48

9+
* parsnip now checks for a valid engine for a given mode (#529)
10+
511
* The default engine for `multinom_reg()` was changed to `nnet`.
612

13+
## Other Changes
14+
715
* The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between formula and matrix interface are now exported for developer use (#508).
816

917
* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).
1018

11-
* A model function (`gen_additive_mod()`) was added for generalized additive models.
1219

1320
# parsnip 0.1.6
1421

R/aaa_models.R

Lines changed: 77 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Initialize model environments
22

3+
all_modes <- c("classification", "regression", "censored regression")
4+
35
# ------------------------------------------------------------------------------
46

57
## Rules about model-related information
@@ -134,25 +136,90 @@ check_mode_val <- function(mode) {
134136
}
135137

136138

137-
stop_incompatible_mode <- function(spec_modes) {
139+
stop_incompatible_mode <- function(spec_modes, eng) {
140+
if (is.null(eng)) {
141+
msg <- glue::glue(
142+
"Available modes are: ",
143+
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
144+
)
145+
} else {
146+
msg <- glue::glue(
147+
"Available modes for engine {eng} are: ",
148+
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
149+
)
150+
}
151+
152+
rlang::abort(msg)
153+
}
154+
155+
stop_incompatible_engine <- function(spec_engs, mode) {
138156
msg <- glue::glue(
139-
"Available modes are: ",
140-
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
157+
"Available engines for mode {mode} are: ",
158+
glue::glue_collapse(glue::glue("'{spec_engs}'"), sep = ", ")
141159
)
142160
rlang::abort(msg)
143161
}
144162

145-
# check if class and mode are compatible
146-
check_spec_mode_val <- function(cls, mode) {
147-
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
163+
# check if class and mode and engine are compatible
164+
check_spec_mode_engine_val <- function(cls, eng, mode) {
165+
all_modes <- c("unknown", all_modes)
166+
if (!(mode %in% all_modes)) {
167+
rlang::abort(paste0("'", mode, "' is not a known mode."))
168+
}
169+
170+
model_info <- rlang::env_get(get_model_env(), cls)
171+
172+
# Cases where the model definition is in parsnip but all of the engines
173+
# are contained in a different package
174+
if (nrow(model_info) == 0) {
175+
return(invisible(NULL))
176+
}
177+
178+
# ------------------------------------------------------------------------------
179+
# First check engine against any mode
180+
181+
spec_engs <- model_info$engine
182+
# engine is allowed to be NULL
183+
if (!is.null(eng) && !(eng %in% spec_engs)) {
184+
rlang::abort(
185+
paste0(
186+
"Engine '", eng, "' is not supported for `", cls, "()`. See ",
187+
"`show_engines('", cls, "')`."
188+
)
189+
)
190+
}
191+
192+
# ----------------------------------------------------------------------------
193+
# Check modes based on model and engine
194+
195+
spec_modes <- model_info$mode
196+
if (!is.null(eng)) {
197+
spec_modes <- spec_modes[model_info$engine == eng]
198+
}
199+
spec_modes <- unique(c("unknown", spec_modes))
200+
148201
if (is.null(mode) || length(mode) > 1) {
149-
stop_incompatible_mode(spec_modes)
202+
stop_incompatible_mode(spec_modes, eng)
150203
} else if (!(mode %in% spec_modes)) {
151-
stop_incompatible_mode(spec_modes)
204+
stop_incompatible_mode(spec_modes, eng)
152205
}
206+
207+
# ----------------------------------------------------------------------------
208+
# Check engine based on model and model
209+
210+
# How check for compatibility with the chosen mode (if any)
211+
if (!is.null(mode) && mode != "unknown") {
212+
spec_engs <- spec_engs[model_info$mode == mode]
213+
}
214+
spec_engs <- unique(spec_engs)
215+
if (!is.null(eng) && !(eng %in% spec_engs)) {
216+
stop_incompatible_engine(spec_engs, mode)
217+
}
218+
153219
invisible(NULL)
154220
}
155221

222+
156223
check_engine_val <- function(eng) {
157224
if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng))
158225
rlang::abort("Please supply a character string for an engine (e.g. `'lm'`).")
@@ -625,8 +692,7 @@ get_dependency <- function(model) {
625692
set_fit <- function(model, mode, eng, value) {
626693
check_model_exists(model)
627694
check_eng_val(eng)
628-
check_mode_val(mode)
629-
check_engine_val(eng)
695+
check_spec_mode_engine_val(model, eng, mode)
630696
check_fit_info(value)
631697

632698
current <- get_model_env()
@@ -692,8 +758,7 @@ get_fit <- function(model) {
692758
set_pred <- function(model, mode, eng, type, value) {
693759
check_model_exists(model)
694760
check_eng_val(eng)
695-
check_mode_val(mode)
696-
check_engine_val(eng)
761+
check_spec_mode_engine_val(model, eng, mode)
697762
check_pred_info(value, type)
698763

699764
current <- get_model_env()

R/arguments.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ set_mode <- function(object, mode) {
8181
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
8282
stop_incompatible_mode(spec_modes)
8383
}
84-
check_spec_mode_val(cls, mode)
84+
check_spec_mode_engine_val(cls, object$engine, mode)
8585
object$mode <- mode
8686
object
8787
}

R/engines.R

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,6 @@ possible_engines <- function(object, ...) {
1010
unique(engs$engine)
1111
}
1212

13-
stop_incompatible_engine <- function(avail_eng) {
14-
msg <- glue::glue(
15-
"Available engines are: ",
16-
glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ")
17-
)
18-
rlang::abort(msg)
19-
}
20-
21-
check_engine <- function(object) {
22-
avail_eng <- possible_engines(object)
23-
eng <- object$engine
24-
if (is.null(eng) || length(eng) > 1) {
25-
stop_incompatible_engine(avail_eng)
26-
} else if (!(eng %in% avail_eng)) {
27-
stop_incompatible_engine(avail_eng)
28-
}
29-
object
30-
}
31-
3213
# ------------------------------------------------------------------------------
3314

3415
shhhh <- function(x)
@@ -99,7 +80,7 @@ set_engine <- function(object, engine, ...) {
9980
stop_incompatible_engine(avail_eng)
10081
}
10182
object$engine <- engine
102-
object <- check_engine(object)
83+
check_spec_mode_engine_val(class(object)[1], object$engine, object$mode)
10384

10485
if (object$engine == "liquidSVM") {
10586
lifecycle::deprecate_soft(

R/misc.R

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ check_empty_ellipse <- function (...) {
2323
terms
2424
}
2525

26-
all_modes <- c("classification", "regression", "censored regression")
27-
28-
2926
deparserizer <- function(x, limit = options()$width - 10) {
3027
x <- deparse(x, width.cutoff = limit)
3128
x <- gsub("^ ", "", x)
@@ -192,7 +189,7 @@ update_dot_check <- function(...) {
192189
#' @rdname add_on_exports
193190
new_model_spec <- function(cls, args, eng_args, mode, method, engine) {
194191

195-
check_spec_mode_val(cls, mode)
192+
check_spec_mode_engine_val(cls, engine, mode)
196193

197194
out <- list(args = args, eng_args = eng_args,
198195
mode = mode, method = method, engine = engine)

R/translate.R

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,15 @@ translate.default <- function(x, engine = x$engine, ...) {
5959
mod_name <- specific_model(x)
6060

6161
x$engine <- engine
62-
x <- check_engine(x)
63-
6462
if (x$mode == "unknown") {
6563
rlang::abort("Model code depends on the mode; please specify one.")
6664
}
6765

68-
if (is.null(x$method))
66+
check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
67+
68+
if (is.null(x$method)) {
6969
x$method <- get_model_spec(mod_name, x$mode, engine)
70+
}
7071

7172
arg_key <- get_args(mod_name, engine)
7273

@@ -174,7 +175,7 @@ deharmonize <- function(args, key) {
174175

175176
add_methods <- function(x, engine) {
176177
x$engine <- engine
177-
x <- check_engine(x)
178+
check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
178179
x$method <- get_model_spec(specific_model(x), x$mode, x$engine)
179180
x
180181
}

tests/testthat/test_args_and_modes.R

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,41 @@ test_that('pipe engine', {
4949
test_that("can't set a mode that isn't allowed by the model spec", {
5050
expect_error(
5151
set_mode(linear_reg(), "classification"),
52-
"Available modes are:"
52+
"Available modes"
5353
)
5454
})
55+
56+
57+
58+
test_that("unavailable modes for an engine and vice-versa", {
59+
expect_error(
60+
decision_tree() %>%
61+
set_mode("regression") %>%
62+
set_engine("C5.0"),
63+
"Available modes for engine C5"
64+
)
65+
expect_error(
66+
decision_tree() %>%
67+
set_engine("C5.0") %>%
68+
set_mode("regression"),
69+
"Available modes for engine C5"
70+
)
71+
72+
expect_error(
73+
decision_tree(engine = NULL) %>%
74+
set_engine("C5.0") %>%
75+
set_mode("regression"),
76+
"Available modes for engine C5"
77+
)
78+
79+
expect_error(
80+
decision_tree(engine = NULL)%>%
81+
set_mode("regression") %>%
82+
set_engine("C5.0"),
83+
"Available modes for engine C5"
84+
)
85+
86+
})
87+
88+
89+

0 commit comments

Comments
 (0)