Skip to content

Commit 859f60f

Browse files
authored
better tuning parameters for brulee (#665)
1 parent 1f51e39 commit 859f60f

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

R/mlp_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ set_model_arg(
420420
eng = "brulee",
421421
parsnip = "learn_rate",
422422
original = "learn_rate",
423-
func = list(pkg = "dials", fun = "learn_rate"),
423+
func = list(pkg = "dials", fun = "learn_rate", range = c(-2.5, -0.5)),
424424
has_submodel = FALSE
425425
)
426426

@@ -448,7 +448,7 @@ set_model_arg(
448448
eng = "brulee",
449449
parsnip = "activation",
450450
original = "activation",
451-
func = list(pkg = "dials", fun = "activation"),
451+
func = list(pkg = "dials", fun = "activation", values = c('relu', 'elu', 'tanh')),
452452
has_submodel = FALSE
453453
)
454454

R/tunable.R

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,21 @@ earth_engine_args <-
137137
component_id = "engine"
138138
)
139139

140+
brulee_engine_args <-
141+
tibble::tibble(
142+
name = c(
143+
"batch_size",
144+
"class_weights"
145+
),
146+
call_info = list(
147+
list(pkg = "dials", fun = "batch_size", range = c(5, 10)),
148+
list(pkg = "dials", fun = "class_weights")
149+
),
150+
source = "model_spec",
151+
component = "mlp",
152+
component_id = "engine"
153+
)
154+
140155
# ------------------------------------------------------------------------------
141156

142157
# Lazily registered in .onLoad()
@@ -227,3 +242,14 @@ tunable_svm_poly <- function(x, ...) {
227242
}
228243
res
229244
}
245+
246+
247+
# Lazily registered in .onLoad()
248+
tunable_mlp <- function(x, ...) {
249+
res <- NextMethod()
250+
if (x$engine == "brulee") {
251+
res <- add_engine_parameters(res, brulee_engine_args)
252+
}
253+
res
254+
}
255+

R/zzz.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
vctrs::s3_register("generics::tunable", "mars", tunable_mars)
4646
vctrs::s3_register("generics::tunable", "decision_tree", tunable_decision_tree)
4747
vctrs::s3_register("generics::tunable", "svm_poly", tunable_svm_poly)
48+
vctrs::s3_register("generics::tunable", "mlp", tunable_mlp)
4849
}
4950

5051
}

0 commit comments

Comments
 (0)