@@ -22,36 +22,41 @@ trunc_probs <- function(probs, trunc = 0.01) {
2222 if (! is.null(eval_time )) {
2323 eval_time <- as.numeric(eval_time )
2424 }
25+ eval_time_0 <- eval_time
2526 # will still propagate nulls:
2627 eval_time <- eval_time [! is.na(eval_time )]
27- eval_time <- unique(eval_time )
28- eval_time <- sort(eval_time )
2928 eval_time <- eval_time [eval_time > = 0 & is.finite(eval_time )]
29+ eval_time <- unique(eval_time )
3030 if (fail && identical(eval_time , numeric (0 ))) {
3131 rlang :: abort(
3232 " There were no usable evaluation times (finite, non-missing, and >= 0)." ,
3333 call = NULL
3434 )
3535 }
36+ if (! identical(eval_time , eval_time_0 )) {
37+ diffs <- setdiff(eval_time_0 , eval_time )
38+ msg <-
39+ cli :: pluralize(
40+ " There {?was/were} {length(diffs)} inappropriate evaluation time point{?s} that {?was/were} removed." )
41+ rlang :: warn(msg )
42+ }
3643 eval_time
3744}
3845
39- add_dot_row_to_weights <- function (dat , rows = NULL ) {
40- if (is.null(rows )) {
41- dat <- add_rowindex(dat )
42- } else {
43- m <- length(rows )
44- n <- nrow(dat )
45- if (m != n ) {
46- rlang :: abort(
47- glue :: glue(
48- " The length of 'rows' ({m}) should be equal to the number of rows in 'data' ({n})"
49- )
50- )
51- }
52- dat $ .row <- rows
# Validate the layout of the `.pred` column of a predictions object.
#
# `x` must have a column named `.pred`, that column must be a list, and the
# first element of that list must contain the columns `.eval_time` and
# `.pred_survival`. Only the first element is inspected. `call` is handed to
# `rlang::abort()` for error reporting. Returns `NULL` invisibly when every
# check passes.
.check_pred_col <- function(x, call = rlang::env_parent()) {
  # Both structural failures share the same message.
  no_list_col_msg <- "The input should have a list column called `.pred`."
  if (!any(names(x) == ".pred")) {
    rlang::abort(no_list_col_msg, call = call)
  }
  if (!is.list(x$.pred)) {
    rlang::abort(no_list_col_msg, call = call)
  }

  needed <- c(".eval_time", ".pred_survival")
  absent <- setdiff(needed, names(x$.pred[[1]]))
  if (length(absent) > 0) {
    quoted <- toString(paste0("'", needed, "'"))
    rlang::abort(
      paste0("The `.pred` tibbles should have columns: ", quoted),
      call = call
    )
  }

  invisible(NULL)
}
5661
5762.check_censor_model <- function (x ) {
@@ -73,7 +78,7 @@ add_dot_row_to_weights <- function(dat, rows = NULL) {
7378# We need to use the time of analysis to determine what time to use to evaluate
7479# the IPCWs.
7580
76- graf_weight_time <- function (surv_obj , eval_time , rows = NULL , eps = 10 ^- 10 ) {
81+ graf_weight_time_vec <- function (surv_obj , eval_time , eps = 10 ^- 10 ) {
7782 event_time <- .extract_surv_time(surv_obj )
7883 status <- .extract_surv_status(surv_obj )
7984 is_event_before_t <- event_time < = eval_time & status == 1
@@ -85,15 +90,14 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
8590 weight_time <- rep(NA_real_ , length(event_time ))
8691
8792 # A real event prior to eval_time (Graf category 1)
88- weight_time [ is_event_before_t ] <- event_time [ is_event_before_t ] - eps
93+ weight_time <- ifelse( is_event_before_t , event_time - eps , weight_time )
8994
9095 # Observed time greater than eval_time (Graf category 2)
91- weight_time [ is_censored ] <- eval_time - eps
96+ weight_time <- ifelse( is_censored , eval_time - eps , weight_time )
9297
9398 weight_time <- ifelse(weight_time < 0 , 0 , weight_time )
9499
95- res <- tibble :: tibble(surv = surv_obj , weight_time = weight_time , eval_time )
96- add_dot_row_to_weights(res , rows )
100+ weight_time
97101}
98102
99103# ------------------------------------------------------------------------------
@@ -102,24 +106,28 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
102106# ' The method of Graf _et al_ (1999) is used to compute weights at specific
103107# ' evaluation times that can be used to help measure a model's time-dependent
104108# ' performance (e.g. the time-dependent Brier score or the area under the ROC
105- # ' curve).
106- # ' @param data A data frame with a column containing a [survival::Surv()] object.
107- # ' @param predictors Not currently used. A potential future slot for models with
108- # ' informative censoring based on columns in `data`.
109- # ' @param rows An optional integer vector with length equal to the number of
110- # ' rows in `data` that is used to index the original data. The default is to
111- # ' use a fresh index on data (i.e. `1:nrow(data)`).
112- # ' @param eval_time A vector of finite, non-negative times at which to
113- # ' compute the probability of censoring and the corresponding weights.
109+ # ' curve). This is an internal function.
110+ # '
111+ # ' @param predictions A data frame with a column containing a [survival::Surv()]
112+ # ' object as well as a list column called `.pred` that contains the data
113+ # ' structure produced by [predict.model_fit()].
114+ # ' @param cens_predictors Not currently used. A potential future slot for models with
115+ # ' informative censoring based on columns in `predictions`.
114116# ' @param object A fitted parsnip model object or fitted workflow with a mode
115117# ' of "censored regression".
116118# ' @param trunc A potential lower bound for the probability of censoring to avoid
117119# ' very large weight values.
118120# ' @param eps A small value that is subtracted from the evaluation time when
119121# ' computing the censoring probabilities. See Details below.
120- # ' @return A tibble with columns `.row`, `eval_time`, `.prob_cens` (the
121- # ' probability of being censored just prior to the evaluation time), and
122- # ' `.weight_cens` (the inverse probability of censoring weight).
122+ # ' @return The same data are returned with the `.pred` tibbles containing
123+ # ' several new columns:
124+ # '
125+ # ' - `.weight_time`: the time at which the inverse censoring probability weights
126+ # ' are computed. This is a function of the observed time and the time of
127+ # ' analysis (i.e., `eval_time`). See Details for more information.
128+ # ' - `.pred_censored`: the probability of being censored at `.weight_time`.
129+ # ' - `.weight_censored`: The inverse of the censoring probability.
130+ # '
123131# ' @details
124132# '
125133# ' A probability that the data are censored immediately prior to a specific
@@ -155,13 +163,21 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
155163# ' The `eps` argument is used to avoid information leakage when computing the
156164# ' censoring probability. Subtracting a small number avoids using data that
157165# ' would not be known at the time of prediction. For example, if we are making
158- # ' survival probability predictions at `eval_time = 3.0`, we would not know the
166+ # ' survival probability predictions at `eval_time = 3.0`, we would _not_ know
159167# ' about the probability of being censored at that exact time (since it has not
160168# ' occurred yet).
161169# '
170+ # ' When creating weights by inverting probabilities, there is the risk that a few
171+ # ' cases will have severe outliers due to probabilities close to zero. To
172+ # ' mitigate this, the `trunc` argument can be used to put a cap on the weights.
173+ # ' If the smallest probability is greater than `trunc`, the probabilities with
174+ # ' values less than `trunc` are given that value. Otherwise, `trunc` is
175+ # ' adjusted to be half of the smallest probability and that value is used as the
176+ # ' lower bound.
177+ # '
162178# ' Note that if there are `n` rows in `predictions` and `t` time points, the resulting
163- # ' data has `n * t` rows. Computations will not easily scale well as `t` becomes
164- # ' large.
179+ # ' data, once unnested, has `n * t` rows. Computations will not easily scale
180+ # ' well as `t` becomes very large.
165181# ' @references Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999),
166182# ' Assessment and comparison of prognostic classification schemes for survival
167183# ' data. _Statist. Med._, 18: 2529-2545.
@@ -185,49 +201,70 @@ graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
185201# ' @export
186202# ' @rdname censoring_weights
.censoring_weights_graf.workflow <- function(object,
                                             predictions,
                                             cens_predictors = NULL,
                                             trunc = 0.05, eps = 10^-10, ...) {
  # Pull out the object stored at `object$fit$fit`; it is what the weights are
  # actually computed from. A workflow without one cannot be used.
  underlying_fit <- object$fit$fit
  if (is.null(underlying_fit)) {
    rlang::abort("The workflow does not have a model fit object.")
  }

  # Hand off to `.censoring_weights_graf()` with the extracted fit. Arguments
  # captured by `...` are not forwarded.
  .censoring_weights_graf(underlying_fit, predictions, cens_predictors, trunc, eps)
}
198212
199213# ' @export
200214# ' @rdname censoring_weights
.censoring_weights_graf.model_fit <- function(object,
                                              predictions,
                                              cens_predictors = NULL,
                                              trunc = 0.05, eps = 10^-10, ...) {
  # Input validation: no stray arguments, a usable censoring model, exactly
  # one outcome column that passes the right-censoring check, and a
  # well-formed `.pred` list column.
  rlang::check_dots_empty()
  .check_censor_model(object)
  surv_col <- .find_surv_col(predictions)
  surv_obj <- predictions[[surv_col]]
  .check_censored_right(surv_obj)
  .check_pred_col(predictions)

  # `cens_predictors` is accepted but has no effect; tell the user so.
  if (!is.null(cens_predictors)) {
    rlang::warn(
      "The 'cens_predictors' argument to the survival weighting function is not currently used."
    )
  }

  # Replace the `.pred` list column with the version that has the weight
  # columns added, and return the otherwise untouched predictions.
  predictions$.pred <- add_graf_weights_vec(
    object,
    predictions$.pred,
    surv_obj,
    trunc = trunc,
    eps = eps
  )
  predictions
}
237+
238+ # ------------------------------------------------------------------------------
239+ # Helpers
240+
# Add the Graf weighting columns to every element of a `.pred` list column.
#
# Arguments:
#   object    Fitted model; its `censor_probs` element is passed to
#             `predict()` to obtain censoring probabilities.
#   .pred     List of data frames (the `.pred` list column).
#   surv_obj  Outcome vector with one entry per element of `.pred`.
#   trunc     Passed to `trunc_probs()` as the truncation value.
#   eps       Passed to `graf_weight_time_vec()`.
#
# Returns a list with the same element sizes as `.pred`, where each data
# frame has gained `.weight_time`, `.pred_censored` and `.weight_censored`.
add_graf_weights_vec <- function(object, .pred, surv_obj, trunc = 0.05, eps = 10^-10) {
  # Expand the list column to one data frame, remembering how many rows each
  # element contributed so the result can be split back up at the end.
  num_times <- vctrs::list_sizes(.pred)
  y <- vctrs::list_unchop(.pred)
  # Repeat each outcome once per row of its corresponding `.pred` element so
  # it lines up with the stacked rows.
  y$surv_obj <- vctrs::vec_rep_each(surv_obj, times = num_times)
  names(y)[names(y) == ".time"] <- ".eval_time" # Temporary
  # Compute the actual time of evaluation
  y$.weight_time <- graf_weight_time_vec(y$surv_obj, y$.eval_time, eps = eps)
  # Compute the corresponding probability of being censored
  y$.pred_censored <- predict(object$censor_probs, time = y$.weight_time, as_vector = TRUE)
  y$.pred_censored <- trunc_probs(y$.pred_censored, trunc = trunc)
  # Invert the probabilities to create weights
  y$.weight_censored <- 1 / y$.pred_censored
  # Drop the helper column and convert back to the list column format
  y$surv_obj <- NULL
  vctrs::vec_chop(y, sizes = num_times)
}
213259
214- truth <- object $ preproc $ y_var
215- if (length(truth ) != 1 ) {
216- # check_outcome() tests that the outcome column is a Surv object
217- rlang :: abort(" The event time data should be in a single column with class 'Surv'" , call = FALSE )
# Locate the outcome column of `x`.
#
# List columns are set aside first; `.is_surv(col, fail = FALSE)` is then
# evaluated on each remaining column. Exactly one column must be flagged,
# otherwise an error is raised (with `call` passed on to `rlang::abort()`).
# Returns the name of the flagged column.
.find_surv_col <- function(x, call = rlang::env_parent()) {
  list_cols <- purrr::map_lgl(x, purrr::is_list)
  candidates <- x[!list_cols]
  surv_flags <- purrr::map_lgl(candidates, .is_surv, fail = FALSE)

  if (sum(surv_flags) != 1) {
    rlang::abort("There should be a single column of class `Surv`", call = call)
  }

  names(surv_flags)[surv_flags]
}
232269
233270# nocov end
0 commit comments