11# Lazily registered in .onLoad()
2+ # Unit tests are in extratests
3+ # nocov start
24tunable_model_spec <- function (x , ... ) {
35 mod_env <- rlang :: ns_env(" parsnip" )$ parsnip
46
@@ -141,11 +143,13 @@ brulee_engine_args <-
141143 tibble :: tibble(
142144 name = c(
143145 " batch_size" ,
144- " class_weights"
146+ " class_weights" ,
147+ " mixture"
145148 ),
146149 call_info = list (
147- list (pkg = " dials" , fun = " batch_size" , range = c(5 , 10 )),
148- list (pkg = " dials" , fun = " class_weights" )
150+ list (pkg = " dials" , fun = " batch_size" , range = c(3 , 10 )),
151+ list (pkg = " dials" , fun = " class_weights" ),
152+ list (pkg = " dials" , fun = " mixture" )
149153 ),
150154 source = " model_spec" ,
151155 component = " mlp" ,
@@ -160,6 +164,8 @@ tunable_linear_reg <- function(x, ...) {
160164 if (x $ engine == " glmnet" ) {
161165 res $ call_info [res $ name == " mixture" ] <-
162166 list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
167+ } else if (x $ engine == " brulee" ) {
168+ res <- add_engine_parameters(res , brulee_engine_args )
163169 }
164170 res
165171}
@@ -170,6 +176,8 @@ tunable_logistic_reg <- function(x, ...) {
170176 if (x $ engine == " glmnet" ) {
171177 res $ call_info [res $ name == " mixture" ] <-
172178 list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
179+ } else if (x $ engine == " brulee" ) {
180+ res <- add_engine_parameters(res , brulee_engine_args )
173181 }
174182 res
175183}
@@ -180,6 +188,8 @@ tunable_multinomial_reg <- function(x, ...) {
180188 if (x $ engine == " glmnet" ) {
181189 res $ call_info [res $ name == " mixture" ] <-
182190 list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
191+ } else if (x $ engine == " brulee" ) {
192+ res <- add_engine_parameters(res , brulee_engine_args )
183193 }
184194 res
185195}
@@ -191,6 +201,8 @@ tunable_boost_tree <- function(x, ...) {
191201 res <- add_engine_parameters(res , xgboost_engine_args )
192202 res $ call_info [res $ name == " sample_size" ] <-
193203 list (list (pkg = " dials" , fun = " sample_prop" ))
204+ res $ call_info [res $ name == " learn_rate" ] <-
205+ list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
194206 } else {
195207 if (x $ engine == " C5.0" ) {
196208 res <- add_engine_parameters(res , c5_boost_engine_args )
@@ -249,7 +261,10 @@ tunable_mlp <- function(x, ...) {
249261 res <- NextMethod()
250262 if (x $ engine == " brulee" ) {
251263 res <- add_engine_parameters(res , brulee_engine_args )
264+ res $ call_info [res $ name == " learn_rate" ] <-
265+ list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
252266 }
253267 res
254268}
255269
270+ # nocov end
0 commit comments