Skip to content

Commit 523cd1b

Browse files
committed
more checks/tests for model and engine
1 parent 9a3cfdd commit 523cd1b

File tree

8 files changed

+53
-11
lines changed

8 files changed

+53
-11
lines changed

R/aaa_models.R

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ all_modes <- c("classification", "regression", "censored regression")
2525

2626
# ------------------------------------------------------------------------------
2727

28-
2928
parsnip <- rlang::new_environment()
3029
parsnip$models <- NULL
31-
parsnip$modes <- c("regression", "classification", "unknown")
30+
parsnip$modes <- c(all_modes, "unknown")
3231

3332
# ------------------------------------------------------------------------------
3433

@@ -160,6 +159,23 @@ stop_incompatible_engine <- function(spec_engs, mode) {
160159
rlang::abort(msg)
161160
}
162161

162+
stop_missing_engine <- function(cls) {
163+
info <-
164+
get_from_env(cls) %>%
165+
dplyr::group_by(mode) %>%
166+
dplyr::summarize(msg = paste0(unique(mode), " {",
167+
paste0(unique(engine), collapse = ", "),
168+
"}"),
169+
.groups = "drop")
170+
if (nrow(info) == 0) {
171+
rlang::abort(paste0("No known engines for `", cls, "()`."))
172+
}
173+
msg <- paste0(info$msg, collapse = ", ")
174+
msg <- paste("Missing engine. Possible mode/engine combinations are:", msg)
175+
rlang::abort(msg)
176+
}
177+
178+
163179
# check if class and mode and engine are compatible
164180
check_spec_mode_engine_val <- function(cls, eng, mode) {
165181
all_modes <- c("unknown", all_modes)
@@ -172,6 +188,7 @@ check_spec_mode_engine_val <- function(cls, eng, mode) {
172188
# Cases where the model definition is in parsnip but all of the engines
173189
# are contained in a different package
174190
if (nrow(model_info) == 0) {
191+
check_mode_with_no_engine(cls, mode)
175192
return(invisible(NULL))
176193
}
177194

@@ -219,6 +236,12 @@ check_spec_mode_engine_val <- function(cls, eng, mode) {
219236
invisible(NULL)
220237
}
221238

239+
check_mode_with_no_engine <- function(cls, mode) {
240+
spec_modes <- get_from_env(paste0(cls, "_modes"))
241+
if (!(mode %in% spec_modes)) {
242+
stop_incompatible_mode(spec_modes, cls)
243+
}
244+
}
222245

223246
check_engine_val <- function(eng) {
224247
if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng))

R/arguments.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ set_mode <- function(object, mode) {
7979
cls <- class(object)[1]
8080
if (rlang::is_missing(mode)) {
8181
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
82-
stop_incompatible_mode(spec_modes)
82+
stop_incompatible_mode(spec_modes, cls)
8383
}
8484
check_spec_mode_engine_val(cls, object$engine, mode)
8585
object$mode <- mode

R/engines.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,16 @@ load_libs <- function(x, quiet, attach = FALSE) {
7171
#' translate(mod, engine = "glmnet")
7272
#' @export
7373
set_engine <- function(object, engine, ...) {
74+
mod_type <- class(object)[1]
7475
if (!inherits(object, "model_spec")) {
7576
rlang::abort("`object` should have class 'model_spec'.")
7677
}
7778

7879
if (rlang::is_missing(engine)) {
79-
avail_eng <- possible_engines(object)
80-
stop_incompatible_engine(avail_eng)
80+
stop_missing_engine(mod_type)
8181
}
8282
object$engine <- engine
83-
check_spec_mode_engine_val(class(object)[1], object$engine, object$mode)
83+
check_spec_mode_engine_val(mod_type, object$engine, object$mode)
8484

8585
if (object$engine == "liquidSVM") {
8686
lifecycle::deprecate_soft(
@@ -90,7 +90,7 @@ set_engine <- function(object, engine, ...) {
9090
}
9191

9292
new_model_spec(
93-
cls = class(object)[1],
93+
cls = mod_type,
9494
args = object$args,
9595
eng_args = enquos(...),
9696
mode = object$mode,

man/details_gen_additive_mod_mgcv.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/extract-parsnip.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/boost_tree_C5.0.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ defaults <-
1313
param <-
1414
boost_tree() %>%
1515
set_engine("C5.0") %>%
16-
set_mode("regression") %>%
16+
set_mode("classification") %>%
1717
tunable() %>%
1818
dplyr::select(-source, -component, -component_id, parsnip = name) %>%
1919
dplyr::mutate(

man/rmd/decision_tree_C5.0.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ defaults <-
1313
param <-
1414
decision_tree() %>%
1515
set_engine("C5.0") %>%
16-
set_mode("regression") %>%
16+
set_mode("classification") %>%
1717
tunable() %>%
1818
dplyr::select(-source, -component, -component_id, parsnip = name) %>%
1919
dplyr::mutate(

tests/testthat/test_args_and_modes.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,25 @@ test_that("unavailable modes for an engine and vice-versa", {
8383
"Available modes for engine C5"
8484
)
8585

86+
expect_error(
87+
proportional_hazards() %>% set_mode("regression"),
88+
"Available modes for engine proportional_hazards"
89+
)
90+
91+
expect_error(
92+
linear_reg() %>% set_mode(),
93+
"Available modes for engine linear_reg"
94+
)
95+
96+
expect_error(
97+
linear_reg() %>% set_engine(),
98+
"Missing engine"
99+
)
100+
101+
expect_error(
102+
proportional_hazards() %>% set_engine(),
103+
"No known engines for"
104+
)
86105
})
87106

88107

0 commit comments

Comments
 (0)