Skip to content

Commit 1e945f4

Browse files
committed
make predictions respect the event_level
1 parent b79ff6d commit 1e945f4

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

R/boost_tree.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,17 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
472472

473473
list(data = dat, watchlist = wlist)
474474
}
475+
476+
get_event_level <- function(model_spec){
477+
if ("event_level" %in% names(model_spec$eng_args)) {
478+
event_level <- get_expr(model_spec$eng_args$event_level)
479+
} else {
480+
# "first" is the default for as_xgb_data() and xgb_train()
481+
event_level <- "first"
482+
}
483+
event_level
484+
}
485+
475486
#' @importFrom purrr map_df
476487
#' @export
477488
#' @rdname multi_predict

R/boost_tree_data.R

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,12 @@ set_pred(
158158
pre = NULL,
159159
post = function(x, object) {
160160
if (is.vector(x)) {
161-
x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1])
161+
event_level <- get_event_level(object$spec)
162+
if (event_level == "first") {
163+
x <- ifelse(x >= 0.5, object$lvl[1], object$lvl[2])
164+
} else {
165+
x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1])
166+
}
162167
} else {
163168
x <- object$lvl[apply(x, 1, which.max)]
164169
}
@@ -178,7 +183,12 @@ set_pred(
178183
pre = NULL,
179184
post = function(x, object) {
180185
if (is.vector(x)) {
181-
x <- tibble(v1 = 1 - x, v2 = x)
186+
event_level <- get_event_level(object$spec)
187+
if (event_level == "first") {
188+
x <- tibble(v1 = x, v2 = 1 - x)
189+
} else {
190+
x <- tibble(v1 = 1 - x, v2 = x)
191+
}
182192
} else {
183193
x <- as_tibble(x, .name_repair = "minimal")
184194
}

0 commit comments

Comments
 (0)