|
21 | 21 | #' @param activation A single character string denoting the type of relationship |
22 | 22 | #' between the original predictors and the hidden unit layer. The activation |
23 | 23 | #' function between the hidden and output layers is automatically set to either |
24 | | -#' "linear" or "softmax" depending on the type of outcome. Possible values are: |
25 | | -#' "linear", "softmax", "relu", and "elu" |
| 24 | +#' "linear" or "softmax" depending on the type of outcome. Possible values |
| 25 | +#' depend on the engine being used. |
26 | 26 | #' |
27 | 27 | #' @templateVar modeltype mlp |
28 | 28 | #' @template spec-details |
@@ -142,24 +142,6 @@ check_args.mlp <- function(object) { |
142 | 142 | if (args$dropout > 0 & args$penalty > 0) |
143 | 143 | rlang::abort("Both weight decay and dropout should not be specified.") |
144 | 144 |
|
145 | | - |
146 | | - if (object$engine == "brulee") { |
147 | | - act_funs <- c("linear", "relu", "elu", "tanh") |
148 | | - } else if (object$engine == "keras") { |
149 | | - act_funs <- c("linear", "softmax", "relu", "elu") |
150 | | - } else if (object$engine == "h2o") { |
151 | | - act_funs <- c("relu", "tanh") |
152 | | - } |
153 | | - |
154 | | - if (is.character(args$activation)) { |
155 | | - if (!any(args$activation %in% c(act_funs))) { |
156 | | - rlang::abort( |
157 | | - glue::glue("`activation` should be one of: ", |
158 | | - glue::glue_collapse(glue::glue("'{act_funs}'"), sep = ", ")) |
159 | | - ) |
160 | | - } |
161 | | - } |
162 | | - |
163 | 145 | invisible(object) |
164 | 146 | } |
165 | 147 |
|
@@ -210,6 +192,9 @@ keras_mlp <- |
210 | 192 | seeds = sample.int(10^5, size = 3), |
211 | 193 | ...) { |
212 | 194 |
|
| 195 | + act_funs <- c("linear", "softmax", "relu", "elu") |
| 196 | + rlang::arg_match(activation, act_funs,) |
| 197 | + |
213 | 198 | if (penalty > 0 & dropout > 0) { |
214 | 199 | rlang::abort("Please use either dropoput or weight decay.", call. = FALSE) |
215 | 200 | } |
|
0 commit comments