Skip to content

Commit 933deb2

Browse files
authored
check outcome class conditional on whether y is atomic (#835)
1 parent fddc746 commit 933deb2

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

R/misc.R

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,19 +322,29 @@ new_model_spec <- function(cls, args, eng_args, mode, user_specified_mode = TRUE
322322
check_outcome <- function(y, spec) {
323323
if (spec$mode == "unknown") {
324324
return(invisible(NULL))
325-
} else if (spec$mode == "regression") {
326-
if (!all(map_lgl(y, is.numeric))) {
325+
}
326+
327+
if (spec$mode == "regression") {
328+
outcome_is_numeric <- if (is.atomic(y)) {is.numeric(y)} else {all(map_lgl(y, is.numeric))}
329+
if (!outcome_is_numeric) {
327330
rlang::abort("For a regression model, the outcome should be numeric.")
328331
}
329-
} else if (spec$mode == "classification") {
330-
if (!all(map_lgl(y, is.factor))) {
332+
}
333+
334+
if (spec$mode == "classification") {
335+
outcome_is_factor <- if (is.atomic(y)) {is.factor(y)} else {all(map_lgl(y, is.factor))}
336+
if (!outcome_is_factor) {
331337
rlang::abort("For a classification model, the outcome should be a factor.")
332338
}
333-
} else if (spec$mode == "censored regression") {
334-
if (!inherits(y, "Surv")) {
339+
}
340+
341+
if (spec$mode == "censored regression") {
342+
outcome_is_surv <- inherits(y, "Surv")
343+
if (!outcome_is_surv) {
335344
rlang::abort("For a censored regression model, the outcome should be a `Surv` object.")
336345
}
337346
}
347+
338348
invisible(NULL)
339349
}
340350

0 commit comments

Comments
 (0)