|
1 | 1 | # Initialize model environments |
2 | 2 |
|
| 3 | +all_modes <- c("classification", "regression", "censored regression") |
| 4 | + |
3 | 5 | # ------------------------------------------------------------------------------ |
4 | 6 |
|
5 | 7 | ## Rules about model-related information |
@@ -134,25 +136,90 @@ check_mode_val <- function(mode) { |
134 | 136 | } |
135 | 137 |
|
136 | 138 |
|
137 | | -stop_incompatible_mode <- function(spec_modes) { |
| 139 | +stop_incompatible_mode <- function(spec_modes, eng) { |
| 140 | + if (is.null(eng)) { |
| 141 | + msg <- glue::glue( |
| 142 | + "Available modes are: ", |
| 143 | + glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") |
| 144 | + ) |
| 145 | + } else { |
| 146 | + msg <- glue::glue( |
| 147 | + "Available modes for engine {eng} are: ", |
| 148 | + glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") |
| 149 | + ) |
| 150 | + } |
| 151 | + |
| 152 | + rlang::abort(msg) |
| 153 | +} |
| 154 | + |
| 155 | +stop_incompatible_engine <- function(spec_engs, mode) { |
138 | 156 | msg <- glue::glue( |
139 | | - "Available modes are: ", |
140 | | - glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") |
| 157 | + "Available engines for mode {mode} are: ", |
| 158 | + glue::glue_collapse(glue::glue("'{spec_engs}'"), sep = ", ") |
141 | 159 | ) |
142 | 160 | rlang::abort(msg) |
143 | 161 | } |
144 | 162 |
|
145 | | -# check if class and mode are compatible |
146 | | -check_spec_mode_val <- function(cls, mode) { |
147 | | - spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) |
| 163 | +# check if class and mode and engine are compatible |
| 164 | +check_spec_mode_engine_val <- function(cls, eng, mode) { |
| 165 | + all_modes <- c("unknown", all_modes) |
| 166 | + if (!(mode %in% all_modes)) { |
| 167 | + rlang::abort(paste0("'", mode, "' is not a known mode.")) |
| 168 | + } |
| 169 | + |
| 170 | + model_info <- rlang::env_get(get_model_env(), cls) |
| 171 | + |
| 172 | + # Cases where the model definition is in parsnip but all of the engines |
| 173 | + # are contained in a different package |
| 174 | + if (nrow(model_info) == 0) { |
| 175 | + return(invisible(NULL)) |
| 176 | + } |
| 177 | + |
| 178 | + # ------------------------------------------------------------------------------ |
| 179 | + # First check engine against any mode |
| 180 | + |
| 181 | + spec_engs <- model_info$engine |
| 182 | + # engine is allowed to be NULL |
| 183 | + if (!is.null(eng) && !(eng %in% spec_engs)) { |
| 184 | + rlang::abort( |
| 185 | + paste0( |
| 186 | + "Engine '", eng, "' is not supported for `", cls, "()`. See ", |
| 187 | + "`show_engines('", cls, "')`." |
| 188 | + ) |
| 189 | + ) |
| 190 | + } |
| 191 | + |
| 192 | + # ---------------------------------------------------------------------------- |
| 193 | + # Check modes based on model and engine |
| 194 | + |
| 195 | + spec_modes <- model_info$mode |
| 196 | + if (!is.null(eng)) { |
| 197 | + spec_modes <- spec_modes[model_info$engine == eng] |
| 198 | + } |
| 199 | + spec_modes <- unique(c("unknown", spec_modes)) |
| 200 | + |
148 | 201 | if (is.null(mode) || length(mode) > 1) { |
149 | | - stop_incompatible_mode(spec_modes) |
| 202 | + stop_incompatible_mode(spec_modes, eng) |
150 | 203 | } else if (!(mode %in% spec_modes)) { |
151 | | - stop_incompatible_mode(spec_modes) |
| 204 | + stop_incompatible_mode(spec_modes, eng) |
152 | 205 | } |
| 206 | + |
| 207 | + # ---------------------------------------------------------------------------- |
| 208 | + # Check engine based on model and model |
| 209 | + |
| 210 | + # How check for compatibility with the chosen mode (if any) |
| 211 | + if (!is.null(mode) && mode != "unknown") { |
| 212 | + spec_engs <- spec_engs[model_info$mode == mode] |
| 213 | + } |
| 214 | + spec_engs <- unique(spec_engs) |
| 215 | + if (!is.null(eng) && !(eng %in% spec_engs)) { |
| 216 | + stop_incompatible_engine(spec_engs, mode) |
| 217 | + } |
| 218 | + |
153 | 219 | invisible(NULL) |
154 | 220 | } |
155 | 221 |
|
| 222 | + |
156 | 223 | check_engine_val <- function(eng) { |
157 | 224 | if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng)) |
158 | 225 | rlang::abort("Please supply a character string for an engine (e.g. `'lm'`).") |
@@ -625,8 +692,7 @@ get_dependency <- function(model) { |
625 | 692 | set_fit <- function(model, mode, eng, value) { |
626 | 693 | check_model_exists(model) |
627 | 694 | check_eng_val(eng) |
628 | | - check_mode_val(mode) |
629 | | - check_engine_val(eng) |
| 695 | + check_spec_mode_engine_val(model, eng, mode) |
630 | 696 | check_fit_info(value) |
631 | 697 |
|
632 | 698 | current <- get_model_env() |
@@ -692,8 +758,7 @@ get_fit <- function(model) { |
692 | 758 | set_pred <- function(model, mode, eng, type, value) { |
693 | 759 | check_model_exists(model) |
694 | 760 | check_eng_val(eng) |
695 | | - check_mode_val(mode) |
696 | | - check_engine_val(eng) |
| 761 | + check_spec_mode_engine_val(model, eng, mode) |
697 | 762 | check_pred_info(value, type) |
698 | 763 |
|
699 | 764 | current <- get_model_env() |
|
0 commit comments