Skip to content

Commit 1909b5d

Browse files
committed
fixed spark models so that class labels (not integers) are in colnames
1 parent cedeba5 commit 1909b5d

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

R/aaa_spark_helpers.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
#' @importFrom dplyr starts_with rename rename_at vars funs
44
format_spark_probs <- function(results, object) {
55
results <- dplyr::select(results, starts_with("probability_"))
6-
results <- dplyr::rename_at(
7-
results,
8-
vars(starts_with("probability_")),
9-
funs(gsub("probability", "pred", .))
10-
)
11-
results
6+
p <- ncol(results)
7+
lvl <- paste0("probability_", 0:(p - 1))
8+
names(lvl) <- paste0("pred_", object$fit$.index_labels)
9+
results %>% rename(!!!syms(lvl))
1210
}
1311

1412
format_spark_class <- function(results, object) {

tests/testthat/test_logistic_reg_spark.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ test_that('spark execution', {
7979
regexp = NA
8080
)
8181

82-
expect_equal(colnames(spark_class_prob), c("pred_0", "pred_1"))
82+
expect_equal(colnames(spark_class_prob), c("pred_Yes", "pred_No"))
8383

8484
expect_equivalent(
8585
as.data.frame(spark_class_prob),

tests/testthat/test_multinom_reg_spark.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ test_that('spark execution', {
6969

7070
expect_equal(
7171
colnames(spark_class_prob),
72-
c("pred_0", "pred_1", "pred_2")
72+
c("pred_versicolor", "pred_virginica", "pred_setosa")
7373
)
7474

7575
expect_equivalent(

0 commit comments

Comments
 (0)