Skip to content

Commit da37f1e

Browse files
committed
allow y as df in fit_xy for #129
1 parent 5bc61d8 commit da37f1e

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

NEWS.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## New Features
44

55
* A "null model" is now available that fits a predictor-free model (using the mean of the outcome for regression or the mode for classification).
6+
* `fit_xy()` can take a single column data frame or matrix for `y` without error
67

78
## Other Changes
89

@@ -38,7 +39,7 @@ First CRAN release
3839

3940
# parsnip 0.0.0.9005
4041

41-
* The engine, and any associated arguments, are now specified using `set_engine`. There is no `engine` argument
42+
* The engine, and any associated arguments, are now specified using `set_engine()`. There is no `engine` argument
4243

4344

4445
# parsnip 0.0.0.9004
@@ -64,7 +65,7 @@ First CRAN release
6465

6566
# parsnip 0.0.0.9000
6667

67-
* The `fit` interface was previously used to cover both the x/y interface as well as the formula interface. Now, `fit` is the formula interface and [`fit_xy` is for the x/y interface](https://github.com/topepo/parsnip/issues/33).
68+
* The `fit` interface was previously used to cover both the x/y interface as well as the formula interface. Now, `fit()` is the formula interface and [`fit_xy()` is for the x/y interface](https://github.com/topepo/parsnip/issues/33).
6869
* Added a `NEWS.md` file to track changes to the package.
6970
* `predict` methods were [overhauled](https://github.com/topepo/parsnip/issues/34) to be [consistent](https://github.com/topepo/parsnip/issues/41).
7071
* MARS was added.

R/fit.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,14 @@ fit_xy.model_spec <-
180180
if (any(names(dots) == "engine"))
181181
stop("Use `set_engine()` to supply the engine.", call. = FALSE)
182182

183+
if (object$engine != "spark" & NCOL(y) == 1 & !(is.vector(y) | is.factor(y))) {
184+
if (is.matrix(y)) {
185+
y <- y[, 1]
186+
} else {
187+
y <- y[[1]]
188+
}
189+
}
190+
183191
cl <- match.call(expand.dots = TRUE)
184192
eval_env <- rlang::env()
185193
eval_env$x <- x

tests/testthat/test_fit_interfaces.R

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ test_that('good args', {
2323
expect_equal( tester(NULL, formula = f, data = iris, model = rmod), "formula")
2424
expect_equal(tester_xy(NULL, x = iris, y = iris, model = rmod), "data.frame")
2525
expect_equal( tester(NULL, f, data = iris, model = rmod), "formula")
26-
expect_equal( tester(NULL, f, data = sprk, model = rmod), "formula")
26+
expect_equal( tester(NULL, f, data = sprk, model = rmod), "formula")
2727
})
2828

2929
#test_that('unnamed args', {
@@ -37,3 +37,26 @@ test_that('wrong args', {
3737
expect_error(tester(NULL, f, data = as.matrix(iris[, 1:4])))
3838
})
3939

40+
test_that('single column df for issue #129', {
41+
42+
expect_error(
43+
lm1 <-
44+
linear_reg() %>%
45+
set_engine("lm") %>%
46+
fit_xy(x = mtcars[, 2:4], y = mtcars[,1, drop = FALSE]),
47+
regexp = NA
48+
)
49+
expect_error(
50+
lm2 <-
51+
linear_reg() %>%
52+
set_engine("lm") %>%
53+
fit_xy(x = mtcars[, 2:4], y = as.matrix(mtcars)[,1, drop = FALSE]),
54+
regexp = NA
55+
)
56+
lm3 <-
57+
linear_reg() %>%
58+
set_engine("lm") %>%
59+
fit_xy(x = mtcars[, 2:4], y = mtcars$mpg)
60+
expect_equal(coef(lm1), coef(lm3))
61+
expect_equal(coef(lm2), coef(lm3))
62+
})

0 commit comments

Comments
 (0)