Skip to content

Commit 728dc12

Browse files
authored
brulee engine args (#810)
* update new engine args * update tunable methods for brulee engines * clean up a bunch of tests * doc update * add a note about the dials version
1 parent 312d191 commit 728dc12

File tree

67 files changed

+638
-197
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+638
-197
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44

55
* `fit_xy()` now fails when the model mode is unknown.
66

7+
* brulee engine-specific tuning parameters were updated. These changes can be used with dials version > 1.0.0.
8+
79
* `fit()` and `fit_xy()` doesn't error anymore if `control` argument isn't a `control_parsnip()` object. Will work as long as the object passed to `control` includes the same elements as `control_parsnip()`.
810

11+
912
# parsnip 1.0.1
1013

1114
* Enabled passing additional engine arguments with the xgboost `boost_tree()` engine. To supply engine-specific arguments that are documented in `xgboost::xgb.train()` as arguments to be passed via `params`, supply the list elements directly as named arguments to `set_engine()`. Read more in `?details_boost_tree_xgboost` (#787).

R/tunable.R

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -158,23 +158,38 @@ earth_engine_args <-
158158
component_id = "engine"
159159
)
160160

161-
brulee_engine_args <-
162-
tibble::tibble(
163-
name = c(
164-
"batch_size",
165-
"class_weights",
166-
"mixture"
167-
),
168-
call_info = list(
169-
list(pkg = "dials", fun = "batch_size", range = c(3, 10)),
170-
list(pkg = "dials", fun = "class_weights"),
171-
list(pkg = "dials", fun = "mixture")
172-
),
173-
source = "model_spec",
174-
component = "mlp",
175-
component_id = "engine"
161+
brulee_mlp_engine_args <-
162+
tibble::tribble(
163+
~name, ~call_info,
164+
"momentum", list(pkg = "dials", fun = "momentum", range = c(0.5, 0.95)),
165+
"batch_size", list(pkg = "dials", fun = "batch_size", range = c(3, 10)),
166+
"stop_iter", list(pkg = "dials", fun = "stop_iter"),
167+
"class_weights", list(pkg = "dials", fun = "class_weights"),
168+
"decay", list(pkg = "dials", fun = "rate_decay"),
169+
"initial", list(pkg = "dials", fun = "rate_initial"),
170+
"largest", list(pkg = "dials", fun = "rate_largest"),
171+
"rate_schedule", list(pkg = "dials", fun = "rate_schedule"),
172+
"step_size", list(pkg = "dials", fun = "rate_step_size"),
173+
"steps", list(pkg = "dials", fun = "rate_steps")
174+
) %>%
175+
dplyr::mutate(,
176+
source = "model_spec",
177+
component = "mlp",
178+
component_id = "engine"
176179
)
177180

181+
brulee_linear_engine_args <-
182+
brulee_mlp_engine_args %>%
183+
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter"))
184+
185+
brulee_logistc_engine_args <-
186+
brulee_mlp_engine_args %>%
187+
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))
188+
189+
brulee_multinomial_engine_args <-
190+
brulee_mlp_engine_args %>%
191+
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))
192+
178193
# ------------------------------------------------------------------------------
179194

180195
# Lazily registered in .onLoad()
@@ -184,7 +199,7 @@ tunable_linear_reg <- function(x, ...) {
184199
res$call_info[res$name == "mixture"] <-
185200
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
186201
} else if (x$engine == "brulee") {
187-
res <- add_engine_parameters(res, brulee_engine_args)
202+
res <- add_engine_parameters(res, brulee_linear_engine_args)
188203
}
189204
res
190205
}
@@ -196,7 +211,7 @@ tunable_logistic_reg <- function(x, ...) {
196211
res$call_info[res$name == "mixture"] <-
197212
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
198213
} else if (x$engine == "brulee") {
199-
res <- add_engine_parameters(res, brulee_engine_args)
214+
res <- add_engine_parameters(res, brulee_logistc_engine_args)
200215
}
201216
res
202217
}
@@ -208,7 +223,7 @@ tunable_multinomial_reg <- function(x, ...) {
208223
res$call_info[res$name == "mixture"] <-
209224
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
210225
} else if (x$engine == "brulee") {
211-
res <- add_engine_parameters(res, brulee_engine_args)
226+
res <- add_engine_parameters(res, brulee_mlp_engine_args)
212227
}
213228
res
214229
}
@@ -223,14 +238,14 @@ tunable_boost_tree <- function(x, ...) {
223238
res$call_info[res$name == "learn_rate"] <-
224239
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
225240
} else if (x$engine == "C5.0") {
226-
res <- add_engine_parameters(res, c5_boost_engine_args)
227-
res$call_info[res$name == "trees"] <-
228-
list(list(pkg = "dials", fun = "trees", range = c(1, 100)))
229-
res$call_info[res$name == "sample_size"] <-
230-
list(list(pkg = "dials", fun = "sample_prop"))
241+
res <- add_engine_parameters(res, c5_boost_engine_args)
242+
res$call_info[res$name == "trees"] <-
243+
list(list(pkg = "dials", fun = "trees", range = c(1, 100)))
244+
res$call_info[res$name == "sample_size"] <-
245+
list(list(pkg = "dials", fun = "sample_prop"))
231246
} else if (x$engine == "lightgbm") {
232-
res$call_info[res$name == "sample_size"] <-
233-
list(list(pkg = "dials", fun = "sample_prop"))
247+
res$call_info[res$name == "sample_size"] <-
248+
list(list(pkg = "dials", fun = "sample_prop"))
234249
}
235250
res
236251
}
@@ -286,11 +301,14 @@ tunable_svm_poly <- function(x, ...) {
286301
tunable_mlp <- function(x, ...) {
287302
res <- NextMethod()
288303
if (x$engine == "brulee") {
289-
res <- add_engine_parameters(res, brulee_engine_args)
304+
res <- add_engine_parameters(res, brulee_mlp_engine_args)
290305
res$call_info[res$name == "learn_rate"] <-
291306
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
307+
res$call_info[res$name == "epochs"] <-
308+
list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L)))
292309
}
293310
res
294311
}
295312

296313
# nocov end
314+

man/details_auto_ml_h2o.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_bag_mars_earth.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_bart_dbarts.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_boost_tree_h2o.Rd

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_boost_tree_lightgbm.Rd

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_boost_tree_xgboost.Rd

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_discrim_flexible_earth.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_discrim_linear_MASS.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)