@@ -322,19 +322,29 @@ new_model_spec <- function(cls, args, eng_args, mode, user_specified_mode = TRUE
322322check_outcome <- function (y , spec ) {
323323 if (spec $ mode == " unknown" ) {
324324 return (invisible (NULL ))
325- } else if (spec $ mode == " regression" ) {
326- if (! all(map_lgl(y , is.numeric ))) {
325+ }
326+
327+ if (spec $ mode == " regression" ) {
328+ outcome_is_numeric <- if (is.atomic(y )) {is.numeric(y )} else {all(map_lgl(y , is.numeric ))}
329+ if (! outcome_is_numeric ) {
327330 rlang :: abort(" For a regression model, the outcome should be numeric." )
328331 }
329- } else if (spec $ mode == " classification" ) {
330- if (! all(map_lgl(y , is.factor ))) {
332+ }
333+
334+ if (spec $ mode == " classification" ) {
335+ outcome_is_factor <- if (is.atomic(y )) {is.factor(y )} else {all(map_lgl(y , is.factor ))}
336+ if (! outcome_is_factor ) {
331337 rlang :: abort(" For a classification model, the outcome should be a factor." )
332338 }
333- } else if (spec $ mode == " censored regression" ) {
334- if (! inherits(y , " Surv" )) {
339+ }
340+
341+ if (spec $ mode == " censored regression" ) {
342+ outcome_is_surv <- inherits(y , " Surv" )
343+ if (! outcome_is_surv ) {
335344 rlang :: abort(" For a censored regression model, the outcome should be a `Surv` object." )
336345 }
337346 }
347+
338348 invisible (NULL )
339349}
340350
0 commit comments