Skip to content

Commit 07961a0

Browse files
authored
fix model fit for spark tbls (#1047)
1 parent 2752453 commit 07961a0

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* Fixed bug in fitting some model types with the `"spark"` engine (#1045).
4+
35
* Fixed issue in `mlp()` metadata where the `stop_iter` engine argument had been mistakenly protected for the `"brulee"` engine. (#1050)
46

57
* `.filter_eval_time()` was moved to the survival standalone file.

R/arguments.R

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -328,7 +328,11 @@ min_cols <- function(num_cols, source) {
328328
#' @export
329329
#' @rdname min_cols
330330
min_rows <- function(num_rows, source, offset = 0) {
331-
n <- nrow(source)
331+
if (inherits(source, "tbl_spark")) {
332+
n <- nrow_spark(source)
333+
} else {
334+
n <- nrow(source)
335+
}
332336

333337
if (num_rows > n - offset) {
334338
msg <- paste0(num_rows, " samples were requested but there were ", n,
@@ -340,3 +344,7 @@ min_rows <- function(num_rows, source, offset = 0) {
340344
as.integer(num_rows)
341345
}
342346

347+
nrow_spark <- function(source) {
348+
rlang::check_installed("sparklyr")
349+
sparklyr::sdf_nrow(source)
350+
}

tests/testthat/test_boost_tree.R

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,9 @@ test_that('bad input', {
2828
## -----------------------------------------------------------------------------
2929

3030
test_that('argument checks for data dimensions', {
31+
skip_if_not_installed("sparklyr")
32+
library(sparklyr)
33+
skip_if(nrow(spark_installed_versions()) == 0)
3134

3235
spec <-
3336
boost_tree(mtry = 1000, min_n = 1000, trees = 5) %>%
@@ -36,6 +39,10 @@ test_that('argument checks for data dimensions', {
3639

3740
args <- translate(spec)$method$fit$args
3841
expect_equal(args$min_instances_per_node, expr(min_rows(1000, x)))
42+
43+
sc = spark_connect(master = "local")
44+
cars = copy_to(sc, mtcars, overwrite = TRUE)
45+
expect_equal(min_rows(10, cars), 10)
3946
})
4047

4148
test_that('boost_tree can be fit with 1 predictor if validation is used', {

0 commit comments

Comments
 (0)