Skip to content

Commit 2fa853e

Browse files
topepohfrick
andauthored
Keep IPCW results in the list column format predicted by the predict() methods (#937)
* replace a few functions * remove older code * update docs and pass tolerance args to computations * version bump * move to vctrs replacements for tidyr functions * comments * make sure that the original first eval_time is still first in the filtered vector * re-doc * vec chopin' * extra bump * update snapshot for new pillar * no need to fiddle with call (could be `NULL` but not `FALSE`) * Apply suggestions from code review Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> * add back .filter_eval_time but at prediction time * more on truncation * warning for bad time points * typo fix * added more snapshots for warnings --------- Co-authored-by: Hannah Frick <hannah@rstudio.com> Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com>
1 parent ef9d376 commit 2fa853e

File tree

9 files changed

+166
-99
lines changed

9 files changed

+166
-99
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.0.4.9005
3+
Version: 1.0.4.9006
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),
@@ -36,7 +36,7 @@ Imports:
3636
tibble (>= 2.1.1),
3737
tidyr (>= 1.3.0),
3838
utils,
39-
vctrs (>= 0.4.1),
39+
vctrs (>= 0.6.0),
4040
withr
4141
Suggests:
4242
C50,

R/ipcw.R

Lines changed: 103 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -22,36 +22,41 @@ trunc_probs <- function(probs, trunc = 0.01) {
2222
if (!is.null(eval_time)) {
2323
eval_time <- as.numeric(eval_time)
2424
}
25+
eval_time_0 <- eval_time
2526
# will still propagate nulls:
2627
eval_time <- eval_time[!is.na(eval_time)]
27-
eval_time <- unique(eval_time)
28-
eval_time <- sort(eval_time)
2928
eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)]
29+
eval_time <- unique(eval_time)
3030
if (fail && identical(eval_time, numeric(0))) {
3131
rlang::abort(
3232
"There were no usable evaluation times (finite, non-missing, and >= 0).",
3333
call = NULL
3434
)
3535
}
36+
if (!identical(eval_time, eval_time_0)) {
37+
diffs <- setdiff(eval_time_0, eval_time)
38+
msg <-
39+
cli::pluralize(
40+
"There {?was/were} {length(diffs)} inappropriate evaluation time point{?s} that {?was/were} removed.")
41+
rlang::warn(msg)
42+
}
3643
eval_time
3744
}
3845

39-
add_dot_row_to_weights <- function(dat, rows = NULL) {
40-
if (is.null(rows)) {
41-
dat <- add_rowindex(dat)
42-
} else {
43-
m <- length(rows)
44-
n <- nrow(dat)
45-
if (m != n) {
46-
rlang::abort(
47-
glue::glue(
48-
"The length of 'rows' ({m}) should be equal to the number of rows in 'data' ({n})"
49-
)
50-
)
51-
}
52-
dat$.row <- rows
46+
.check_pred_col <- function(x, call = rlang::env_parent()) {
47+
if (!any(names(x) == ".pred")) {
48+
rlang::abort("The input should have a list column called `.pred`.", call = call)
49+
}
50+
if (!is.list(x$.pred)) {
51+
rlang::abort("The input should have a list column called `.pred`.", call = call)
5352
}
54-
dat
53+
req_cols <- c(".eval_time", ".pred_survival")
54+
if (!all(req_cols %in% names(x$.pred[[1]]))) {
55+
msg <- paste0("The `.pred` tibbles should have columns: ",
56+
paste0("'", req_cols, "'", collapse = ", "))
57+
rlang::abort(msg, call = call)
58+
}
59+
invisible(NULL)
5560
}
5661

5762
.check_censor_model <- function(x) {
@@ -73,7 +78,7 @@ add_dot_row_to_weights <- function(dat, rows = NULL) {
7378
# We need to use the time of analysis to determine what time to use to evaluate
7479
# the IPCWs.
7580

76-
graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
81+
graf_weight_time_vec <- function(surv_obj, eval_time, eps = 10^-10) {
7782
event_time <- .extract_surv_time(surv_obj)
7883
status <- .extract_surv_status(surv_obj)
7984
is_event_before_t <- event_time <= eval_time & status == 1
@@ -85,15 +90,14 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
8590
weight_time <- rep(NA_real_, length(event_time))
8691

8792
# A real event prior to eval_time (Graf category 1)
88-
weight_time[is_event_before_t] <- event_time[is_event_before_t] - eps
93+
weight_time <- ifelse(is_event_before_t, event_time - eps, weight_time)
8994

9095
# Observed time greater than eval_time (Graf category 2)
91-
weight_time[is_censored] <- eval_time - eps
96+
weight_time <- ifelse(is_censored, eval_time - eps, weight_time)
9297

9398
weight_time <- ifelse(weight_time < 0, 0, weight_time)
9499

95-
res <- tibble::tibble(surv = surv_obj, weight_time = weight_time, eval_time)
96-
add_dot_row_to_weights(res, rows)
100+
weight_time
97101
}
98102

99103
# ------------------------------------------------------------------------------
@@ -102,24 +106,28 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
102106
#' The method of Graf _et al_ (1999) is used to compute weights at specific
103107
#' evaluation times that can be used to help measure a model's time-dependent
104108
#' performance (e.g. the time-dependent Brier score or the area under the ROC
105-
#' curve).
106-
#' @param data A data frame with a column containing a [survival::Surv()] object.
107-
#' @param predictors Not currently used. A potential future slot for models with
108-
#' informative censoring based on columns in `data`.
109-
#' @param rows An optional integer vector with length equal to the number of
110-
#' rows in `data` that is used to index the original data. The default is to
111-
#' use a fresh index on data (i.e. `1:nrow(data)`).
112-
#' @param eval_time A vector of finite, non-negative times at which to
113-
#' compute the probability of censoring and the corresponding weights.
109+
#' curve). This is an internal function.
110+
#'
111+
#' @param predictions A data frame with a column containing a [survival::Surv()]
112+
#' object as well as a list column called `.pred` that contains the data
113+
#' structure produced by [predict.model_fit()].
114+
#' @param cens_predictors Not currently used. A potential future slot for models with
115+
#' informative censoring based on columns in `predictions`.
114116
#' @param object A fitted parsnip model object or fitted workflow with a mode
115117
#' of "censored regression".
116118
#' @param trunc A potential lower bound for the probability of censoring to avoid
117119
#' very large weight values.
118120
#' @param eps A small value that is subtracted from the evaluation time when
119121
#' computing the censoring probabilities. See Details below.
120-
#' @return A tibble with columns `.row`, `eval_time`, `.prob_cens` (the
121-
#' probability of being censored just prior to the evaluation time), and
122-
#' `.weight_cens` (the inverse probability of censoring weight).
122+
#' @return The same data are returned with the `pred` tibbles containing
123+
#' several new columns:
124+
#'
125+
#' - `.weight_time`: the time at which the inverse censoring probability weights
126+
#' are computed. This is a function of the observed time and the time of
127+
#' analysis (i.e., `eval_time`). See Details for more information.
128+
#' - `.pred_censored`: the probability of being censored at `.weight_time`.
129+
#' - `.weight_censored`: The inverse of the censoring probability.
130+
#'
123131
#' @details
124132
#'
125133
#' A probability that the data are censored immediately prior to a specific
@@ -155,13 +163,21 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
155163
#' The `eps` argument is used to avoid information leakage when computing the
156164
#' censoring probability. Subtracting a small number avoids using data that
157165
#' would not be known at the time of prediction. For example, if we are making
158-
#' survival probability predictions at `eval_time = 3.0`, we would not know the
166+
#' survival probability predictions at `eval_time = 3.0`, we would _not_ know the
159167
#' about the probability of being censored at that exact time (since it has not
160168
#' occurred yet).
161169
#'
170+
#' When creating weights by inverting probabilities, there is the risk that a few
171+
#' cases will have severe outliers due to probabilities close to zero. To
172+
#' mitigate this, the `trunc` argument can be used to put a cap on the weights.
173+
#' If the smallest probability is greater than `trunc`, the probabilities with
174+
#' values less than `trunc` are given that value. Otherwise, `trunc` is
175+
#' adjusted to be half of the smallest probability and that value is used as the
176+
#' lower bound..
177+
#'
162178
#' Note that if there are `n` rows in `data` and `t` time points, the resulting
163-
#' data has `n * t` rows. Computations will not easily scale well as `t` becomes
164-
#' large.
179+
#' data, once unnested, has `n * t` rows. Computations will not easily scale
180+
#' well as `t` becomes very large.
165181
#' @references Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999),
166182
#' Assessment and comparison of prognostic classification schemes for survival
167183
#' data. _Statist. Med._, 18: 2529-2545.
@@ -185,49 +201,70 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
185201
#' @export
186202
#' @rdname censoring_weights
187203
.censoring_weights_graf.workflow <- function(object,
188-
data,
189-
eval_time,
190-
rows = NULL,
191-
predictors = NULL,
204+
predictions,
205+
cens_predictors = NULL,
192206
trunc = 0.05, eps = 10^-10, ...) {
193207
if (is.null(object$fit$fit)) {
194-
rlang::abort("The workflow does not have a model fit object.", call = FALSE)
208+
rlang::abort("The workflow does not have a model fit object.")
195209
}
196-
.censoring_weights_graf(object$fit$fit, data, eval_time, rows, predictors, trunc, eps)
210+
.censoring_weights_graf(object$fit$fit, predictions, cens_predictors, trunc, eps)
197211
}
198212

199213
#' @export
200214
#' @rdname censoring_weights
201215
.censoring_weights_graf.model_fit <- function(object,
202-
data,
203-
eval_time,
204-
rows = NULL,
205-
predictors = NULL,
216+
predictions,
217+
cens_predictors = NULL,
206218
trunc = 0.05, eps = 10^-10, ...) {
207219
rlang::check_dots_empty()
208220
.check_censor_model(object)
209-
if (!is.null(predictors)) {
210-
rlang::warn("The 'predictors' argument to the survival weighting function is not currently used.", call = FALSE)
221+
truth <- .find_surv_col(predictions)
222+
.check_censored_right(predictions[[truth]])
223+
.check_pred_col(predictions)
224+
225+
if (!is.null(cens_predictors)) {
226+
msg <- "The 'cens_predictors' argument to the survival weighting function is not currently used."
227+
rlang::warn(msg)
211228
}
212-
eval_time <- .filter_eval_time(eval_time)
229+
predictions$.pred <-
230+
add_graf_weights_vec(object,
231+
predictions$.pred,
232+
predictions[[truth]],
233+
trunc = trunc,
234+
eps = eps)
235+
predictions
236+
}
237+
238+
# ------------------------------------------------------------------------------
239+
# Helpers
240+
241+
add_graf_weights_vec <- function(object, .pred, surv_obj, trunc = 0.05, eps = 10^-10) {
242+
# Expand the list column to one data frame
243+
n <- length(.pred)
244+
num_times <- vctrs::list_sizes(.pred)
245+
y <- vctrs::list_unchop(.pred)
246+
y$surv_obj <- vctrs::vec_rep_each(surv_obj, times = num_times)
247+
names(y)[names(y) == ".time"] <- ".eval_time" # Temporary
248+
# Compute the actual time of evaluation
249+
y$.weight_time <- graf_weight_time_vec(y$surv_obj, y$.eval_time, eps = eps)
250+
# Compute the corresponding probability of being censored
251+
y$.pred_censored <- predict(object$censor_probs, time = y$.weight_time, as_vector = TRUE)
252+
y$.pred_censored <- trunc_probs(y$.pred_censored, trunc = trunc)
253+
# Invert the probabilities to create weights
254+
y$.weight_censored = 1 / y$.pred_censored
255+
# Convert back the list column format
256+
y$surv_obj <- NULL
257+
vctrs::vec_chop(y, sizes = num_times)
258+
}
213259

214-
truth <- object$preproc$y_var
215-
if (length(truth) != 1) {
216-
# check_outcome() tests that the outcome column is a Surv object
217-
rlang::abort("The event time data should be in a single column with class 'Surv'", call = FALSE)
260+
.find_surv_col <- function(x, call = rlang::env_parent()) {
261+
is_lst_col <- purrr::map_lgl(x, purrr::is_list)
262+
is_surv <- purrr::map_lgl(x[!is_lst_col], .is_surv, fail = FALSE)
263+
num_surv <- sum(is_surv)
264+
if (num_surv != 1) {
265+
rlang::abort("There should be a single column of class `Surv`", call = call)
218266
}
219-
surv_data <- dplyr::select(data, dplyr::all_of(!!truth)) %>% setNames("surv")
220-
.check_censored_right(surv_data$surv)
221-
222-
purrr::map(eval_time,
223-
~ graf_weight_time(surv_data$surv, .x, eps = eps, rows = rows)) %>%
224-
purrr::list_rbind() %>%
225-
dplyr::mutate(
226-
.prob_cens = predict(object$censor_probs, time = weight_time, as_vector = TRUE),
227-
.prob_cens = trunc_probs(.prob_cens, trunc),
228-
.weight_cens = 1 / .prob_cens
229-
) %>%
230-
dplyr::select(.row, eval_time, .prob_cens, .weight_cens)
267+
names(is_surv)[is_surv]
231268
}
232269

233270
# nocov end

R/predict.R

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,13 @@
4949
#' ## Censored regression predictions
5050
#'
5151
#' For censored regression, a numeric vector for `eval_time` is required when
52-
#' survival or hazard probabilities are requested. Also, when
53-
#' `type = "linear_pred"`, censored regression models will by default be
54-
#' formatted such that the linear predictor _increases_ with time. This may
52+
#' survival or hazard probabilities are requested. The time values are required
53+
#' to be unique, finite, non-missing, and non-negative. The `predict()`
54+
#' functions will adjust the values to fit this specification by removing
55+
#' offending points (with a warning).
56+
#'
57+
#' Also, when `type = "linear_pred"`, censored regression models will by default
58+
#' be formatted such that the linear predictor _increases_ with time. This may
5559
#' have the opposite sign as what the underlying model's `predict()` method
5660
#' produces. Set `increasing = FALSE` to suppress this behavior.
5761
#'

R/predict_hazard.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ predict_hazard.model_fit <- function(object,
1717
)
1818
eval_time <- time
1919
}
20+
eval_time <- .filter_eval_time(eval_time)
2021

2122
check_spec_pred_type(object, "hazard")
2223

R/predict_survival.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ predict_survival.model_fit <- function(object,
1919
)
2020
eval_time <- time
2121
}
22+
eval_time <- .filter_eval_time(eval_time)
2223

2324
check_spec_pred_type(object, "survival")
2425

0 commit comments

Comments
 (0)