Skip to content

Commit c36e7de

Browse files
committed
Merge branch 'master' of https://github.com/topepo/parsnip
2 parents 88d622c + 8fa11cf commit c36e7de

File tree

2 files changed

+73
-6
lines changed

2 files changed

+73
-6
lines changed

R/engines.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,13 @@ check_engine <- function(object) {
4242
#' @importFrom utils installed.packages
4343
check_installs <- function(x) {
4444
lib_inst <- rownames(installed.packages())
45-
if(length(x$method$library) > 0) {
45+
if (length(x$method$library) > 0) {
4646
is_inst <- x$method$library %in% lib_inst
47-
if(any(!is_inst)) {
48-
stop("This engine requires some package installs: ",
49-
paste0("'", x$method$library[!is_inst], "'", collapse = ", "))
47+
if (any(!is_inst)) {
48+
stop(
49+
"This engine requires some package installs: ",
50+
paste0("'", x$method$library[!is_inst], "'", collapse = ", ")
51+
)
5052
}
5153
}
5254
}

R/rand_forest.R

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,37 @@ get_randomForest_regression <- function () {
205205
list(library = libs, interface = interface, fit = fit, protect = protect)
206206
}
207207

208+
get_sparklyr_regression <- function () {
209+
libs <- "sparklyr"
210+
interface <- "data.frame" # adjust this to something else
211+
protect = c("x", "formula", "label_col", "features_col")
212+
fit <-
213+
quote(
214+
ml_random_forest_regressor(
215+
x = x,
216+
formula = NULL,
217+
num_trees = 20L,
218+
subsampling_rate = 1,
219+
max_depth = 5L,
220+
min_instances_per_node = 1L,
221+
feature_subset_strategy = "auto",
222+
impurity = "variance",
223+
min_info_gain = 0,
224+
max_bins = 32L,
225+
seed = NULL,
226+
checkpoint_interval = 10L,
227+
cache_node_ids = FALSE,
228+
max_memory_in_mb = 256L,
229+
features_col = "features",
230+
label_col = "label",
231+
prediction_col = "prediction",
232+
uid = random_string("random_forest_regressor_"),
233+
...
234+
)
235+
)
236+
list(library = libs, interface = interface, fit = fit, protect = protect)
237+
}
238+
208239

209240
get_ranger_classification <- function () {
210241
libs <- "ranger"
@@ -281,6 +312,40 @@ get_randomForest_classification <- function () {
281312
list(library = libs, interface = interface, fit = fit, protect = protect)
282313
}
283314

315+
get_sparklyr_regression <- function () {
316+
libs <- "sparklyr"
317+
interface <- "data.frame" # adjust this to something else
318+
protect = c("x", "formula", "label_col", "features_col")
319+
fit <-
320+
quote(
321+
ml_random_forest_classifier(
322+
x = x,
323+
formula = NULL,
324+
num_trees = 20L,
325+
subsampling_rate = 1,
326+
max_depth = 5L,
327+
min_instances_per_node = 1L,
328+
feature_subset_strategy = "auto",
329+
impurity = "gini",
330+
min_info_gain = 0,
331+
max_bins = 32L,
332+
seed = NULL,
333+
thresholds = NULL,
334+
checkpoint_interval = 10L,
335+
cache_node_ids = FALSE,
336+
max_memory_in_mb = 256L,
337+
features_col = "features",
338+
label_col = "label",
339+
prediction_col = "prediction",
340+
probability_col = "probability",
341+
raw_prediction_col = "rawPrediction",
342+
uid = random_string("random_forest_classifier_"),
343+
...
344+
)
345+
)
346+
list(library = libs, interface = interface, fit = fit, protect = protect)
347+
}
348+
284349

285350
###################################################################
286351

@@ -382,7 +447,7 @@ update.rand_forest <-
382447
rand_forest_arg_key <- data.frame(
383448
randomForest = c("mtry", "ntree", "nodesize"),
384449
ranger = c("mtry", "num.trees", "min.node.size"),
385-
spark =
450+
sparklyr =
386451
c("feature_subset_strategy", "num_trees", "min_instances_per_node"),
387452
stringsAsFactors = FALSE,
388453
row.names = c("mtry", "trees", "min_n")
@@ -393,7 +458,7 @@ rand_forest_modes <- c("classification", "regression", "unknown")
393458
rand_forest_engines <- data.frame(
394459
ranger = c(TRUE, TRUE, FALSE),
395460
randomForest = c(TRUE, TRUE, FALSE),
396-
spark = c(TRUE, TRUE, FALSE),
461+
sparklyr = c(TRUE, TRUE, FALSE),
397462
row.names = c("classification", "regression", "unknown")
398463
)
399464

0 commit comments

Comments
 (0)