1-
21# Unit tests are in extratests
32# nocov start
43
54# ' @export
65tunable.model_spec <- function (x , ... ) {
7-
86 mod_env <- get_model_env()
97
108 if (is.null(x $ engine )) {
@@ -13,9 +11,14 @@ tunable.model_spec <- function(x, ...) {
1311
1412 arg_name <- paste0(mod_type(x ), " _args" )
1513 if (! (any(arg_name == names(mod_env )))) {
16- stop(" The `parsnip` model database doesn't know about the arguments for " ,
17- " model `" , mod_type(x ), " `. Was it registered?" ,
18- sep = " " , call. = FALSE )
14+ stop(
15+ " The `parsnip` model database doesn't know about the arguments for " ,
16+ " model `" ,
17+ mod_type(x ),
18+ " `. Was it registered?" ,
19+ sep = " " ,
20+ call. = FALSE
21+ )
1922 }
2023
2124 arg_vals <- mod_env [[arg_name ]]
@@ -28,7 +31,10 @@ tunable.model_spec <- function(x, ...) {
2831
2932 extra_args_tbl <-
3033 tibble :: new_tibble(
31- list (name = extra_args , call_info = vector(" list" , vctrs :: vec_size(extra_args ))),
34+ list (
35+ name = extra_args ,
36+ call_info = vector(" list" , vctrs :: vec_size(extra_args ))
37+ ),
3238 nrow = vctrs :: vec_size(extra_args )
3339 )
3440
@@ -57,7 +63,7 @@ add_engine_parameters <- function(pset, engines) {
5763 is_engine_param <- pset $ name %in% engines $ name
5864 if (any(is_engine_param )) {
5965 engine_names <- pset $ name [is_engine_param ]
60- pset <- pset [! is_engine_param ,]
66+ pset <- pset [! is_engine_param , ]
6167 pset <-
6268 dplyr :: bind_rows(pset , engines | > dplyr :: filter(name %in% engines $ name ))
6369 }
@@ -213,9 +219,22 @@ tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
213219
214220brulee_mlp_args <-
215221 tibble :: tibble(
216- name = c(' epochs' , ' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' ,
217- ' penalty' , ' mixture' , ' dropout' , ' learn_rate' , ' momentum' , ' batch_size' ,
218- ' class_weights' , ' stop_iter' , ' rate_schedule' ),
222+ name = c(
223+ ' epochs' ,
224+ ' hidden_units' ,
225+ ' hidden_units_2' ,
226+ ' activation' ,
227+ ' activation_2' ,
228+ ' penalty' ,
229+ ' mixture' ,
230+ ' dropout' ,
231+ ' learn_rate' ,
232+ ' momentum' ,
233+ ' batch_size' ,
234+ ' class_weights' ,
235+ ' stop_iter' ,
236+ ' rate_schedule'
237+ ),
219238 call_info = list (
220239 list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )),
221240 list (pkg = " dials" , fun = " hidden_units" , range = c(2L , 50L )),
@@ -225,9 +244,9 @@ brulee_mlp_args <-
225244 list (pkg = " dials" , fun = " penalty" ),
226245 list (pkg = " dials" , fun = " mixture" ),
227246 list (pkg = " dials" , fun = " dropout" ),
228- list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 5 )),
229- list (pkg = " dials" , fun = " momentum" , range = c(0.50 , 0.95 )),
230- list (pkg = " dials" , fun = " batch_size" ),
247+ list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 5 )),
248+ list (pkg = " dials" , fun = " momentum" , range = c(0.00 , 0.99 )),
249+ list (pkg = " dials" , fun = " batch_size" , range = c( 3L , 8L ) ),
231250 list (pkg = " dials" , fun = " class_weights" ),
232251 list (pkg = " dials" , fun = " stop_iter" ),
233252 list (pkg = " dials" , fun = " rate_schedule" , values = tune_sched )
@@ -237,8 +256,13 @@ brulee_mlp_args <-
237256
238257brulee_mlp_only_args <-
239258 tibble :: tibble(
240- name =
241- c(' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' , ' dropout' )
259+ name = c(
260+ ' hidden_units' ,
261+ ' hidden_units_2' ,
262+ ' activation' ,
263+ ' activation_2' ,
264+ ' dropout'
265+ )
242266 )
243267
244268# ------------------------------------------------------------------------------
@@ -256,7 +280,11 @@ tunable.linear_reg <- function(x, ...) {
256280 dplyr :: filter(name != " class_weights" ) | >
257281 dplyr :: mutate(
258282 component = " linear_reg" ,
259- component_id = ifelse(name %in% names(formals(" linear_reg" )), " main" , " engine" )
283+ component_id = ifelse(
284+ name %in% names(formals(" linear_reg" )),
285+ " main" ,
286+ " engine"
287+ )
260288 ) | >
261289 dplyr :: select(name , call_info , source , component , component_id )
262290 }
@@ -277,7 +305,11 @@ tunable.logistic_reg <- function(x, ...) {
277305 dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) | >
278306 dplyr :: mutate(
279307 component = " logistic_reg" ,
280- component_id = ifelse(name %in% names(formals(" logistic_reg" )), " main" , " engine" )
308+ component_id = ifelse(
309+ name %in% names(formals(" logistic_reg" )),
310+ " main" ,
311+ " engine"
312+ )
281313 ) | >
282314 dplyr :: select(name , call_info , source , component , component_id )
283315 }
@@ -296,7 +328,11 @@ tunable.multinom_reg <- function(x, ...) {
296328 dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) | >
297329 dplyr :: mutate(
298330 component = " multinom_reg" ,
299- component_id = ifelse(name %in% names(formals(" multinom_reg" )), " main" , " engine" )
331+ component_id = ifelse(
332+ name %in% names(formals(" multinom_reg" )),
333+ " main" ,
334+ " engine"
335+ )
300336 ) | >
301337 dplyr :: select(name , call_info , source , component , component_id )
302338 }
@@ -311,7 +347,7 @@ tunable.boost_tree <- function(x, ...) {
311347 res $ call_info [res $ name == " sample_size" ] <-
312348 list (list (pkg = " dials" , fun = " sample_prop" ))
313349 res $ call_info [res $ name == " learn_rate" ] <-
314- list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
350+ list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
315351 } else if (x $ engine == " C5.0" ) {
316352 res <- add_engine_parameters(res , c5_boost_engine_args )
317353 res $ call_info [res $ name == " trees" ] <-
@@ -357,9 +393,11 @@ tunable.decision_tree <- function(x, ...) {
357393 res <- add_engine_parameters(res , c5_tree_engine_args )
358394 } else if (x $ engine == " partykit" ) {
359395 res <-
360- add_engine_parameters(res ,
361- partykit_engine_args | >
362- dplyr :: mutate(component = " decision_tree" ))
396+ add_engine_parameters(
397+ res ,
398+ partykit_engine_args | >
399+ dplyr :: mutate(component = " decision_tree" )
400+ )
363401 }
364402 res
365403}
@@ -386,7 +424,7 @@ tunable.mlp <- function(x, ...) {
386424 ) | >
387425 dplyr :: select(name , call_info , source , component , component_id )
388426 if (x $ engine == " brulee" ) {
389- res <- res [! grepl(" _2" , res $ name ),]
427+ res <- res [! grepl(" _2" , res $ name ), ]
390428 }
391429 }
392430 res
@@ -402,4 +440,3 @@ tunable.survival_reg <- function(x, ...) {
402440}
403441
404442# nocov end
405-
0 commit comments