77form_form <-
88 function (object , control , env , ... ) {
99
10+ # prob rewrite this as simple subset/levels
11+ y_levels <- levels_from_formula(env $ formula , env $ data )
12+
1013 if (object $ mode == " classification" ) {
11- # prob rewrite this as simple subset/levels
12- y_levels <- levels_from_formula(env $ formula , env $ data )
1314 if (! inherits(env $ data , " tbl_spark" ) && is.null(y_levels ))
14- rlang :: abort(" For classification models, the outcome should be a factor." )
15- } else {
16- y_levels <- NULL
15+ rlang :: abort(" For a classification model, the outcome should be a factor." )
16+ } else if (object $ mode == " regression" ) {
17+ if (! inherits(env $ data , " tbl_spark" ) && ! is.null(y_levels ))
18+ rlang :: abort(" For a regression model, the outcome should be numeric." )
1719 }
1820
1921 object <- check_mode(object , y_levels )
@@ -57,11 +59,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
5759 rlang :: abort(" spark objects can only be used with the formula interface to `fit()`" )
5860
5961 object <- check_mode(object , levels(env $ y ))
60-
61- if (object $ mode == " classification" ) {
62- if (is.null(levels(env $ y )))
63- rlang :: abort(" For classification models, the outcome should be a factor." )
64- }
62+ check_outcome(env $ y , object )
6563
6664 encoding_info <-
6765 get_encoding(class(object )[1 ]) %> %
@@ -133,7 +131,10 @@ form_xy <- function(object, control, env,
133131 res <- list (lvl = levels_from_formula(env $ formula , env $ data ), spec = object )
134132 if (object $ mode == " classification" ) {
135133 if (is.null(res $ lvl ))
136- rlang :: abort(" For classification models, the outcome should be a factor." )
134+ rlang :: abort(" For a classification model, the outcome should be a factor." )
135+ } else if (object $ mode == " regression" ) {
136+ if (! is.null(res $ lvl ))
137+ rlang :: abort(" For a regression model, the outcome should be numeric." )
137138 }
138139
139140 res <- xy_xy(
@@ -153,10 +154,7 @@ form_xy <- function(object, control, env,
153154
154155xy_form <- function (object , env , control , ... ) {
155156
156- if (object $ mode == " classification" ) {
157- if (is.null(levels(env $ y )))
158- rlang :: abort(" For classification models, the outcome should be a factor." )
159- }
157+ check_outcome(env $ y , object )
160158
161159 encoding_info <-
162160 get_encoding(class(object )[1 ]) %> %
0 commit comments