|
| 1 | +# nocov start |
| 2 | +# tested in tidymodels/extratests#67 |
| 3 | + |
| 4 | +new_reverse_km_fit <- |
| 5 | + function(formula, |
| 6 | + object, |
| 7 | + pkgs = character(0), |
| 8 | + label = character(0), |
| 9 | + extra_cls = character(0)) { |
| 10 | + res <- list(formula = formula, fit = object, label = label, required_pkgs = pkgs) |
| 11 | + class(res) <- c(paste0("censoring_model_", label), "censoring_model", extra_cls) |
| 12 | + res |
| 13 | + } |
| 14 | + |
| 15 | +# ------------------------------------------------------------------------------ |
| 16 | +# estimate the reverse km curve for censored regression models |
| 17 | + |
| 18 | +reverse_km <- function(obj, eval_env) { |
| 19 | + if (obj$mode != "censored regression") { |
| 20 | + return(list()) |
| 21 | + } |
| 22 | + rlang::check_installed("prodlim") |
| 23 | + |
| 24 | + # Note: even when fit_xy() is called, eval_env will still have |
| 25 | + # objects data and formula in them |
| 26 | + f <- eval_env$formula |
| 27 | + km_form <- stats::update(f, ~ 1) |
| 28 | + cl <- |
| 29 | + rlang::call2( |
| 30 | + "prodlim", |
| 31 | + formula = km_form, |
| 32 | + .ns = "prodlim", |
| 33 | + reverse = TRUE, |
| 34 | + type = "surv", |
| 35 | + x = FALSE, |
| 36 | + data = rlang::expr(eval_env$data) |
| 37 | + ) |
| 38 | + |
| 39 | + if (!is.null(eval_env$weights)) { |
| 40 | + cl <- rlang::call_modify(cl, caseweights = rlang::expr(eval_env$weights)) |
| 41 | + } |
| 42 | + rkm <- try(rlang::eval_tidy(cl), silent = TRUE) |
| 43 | + new_reverse_km_fit(f, object = rkm, label = "reverse_km", pkgs = "prodlim") |
| 44 | +} |
| 45 | + |
| 46 | +# ------------------------------------------------------------------------------ |
| 47 | +# Basic S3 methods |
| 48 | + |
| 49 | +print.censoring_model <- function(x, ...) { |
| 50 | + cat(x$label, "model for predicting the probability of censoring\n") |
| 51 | + invisible(x) |
| 52 | +} |
| 53 | + |
| 54 | +predict.censoring_model <- function(object, ...) { |
| 55 | + rlang::abort( |
| 56 | + paste("Don't know how to predict with a censoring model of type:", object$label) |
| 57 | + ) |
| 58 | + invisible(NULL) |
| 59 | +} |
| 60 | + |
| 61 | +#' @export |
| 62 | +predict.censoring_model_reverse_km <- function(object, new_data = NULL, time, as_vector = FALSE, ...) { |
| 63 | + rlang::check_installed("prodlim") |
| 64 | + |
| 65 | + res <- rep(NA_real_, length(time)) |
| 66 | + if (length(time) == 0) { |
| 67 | + return(res) |
| 68 | + } |
| 69 | + |
| 70 | + # Some time values might be NA (for Graf category 2) |
| 71 | + is_na <- which(is.na(time)) |
| 72 | + if (length(is_na) > 0) { |
| 73 | + time <- time[-is_na] |
| 74 | + } |
| 75 | + |
| 76 | + if (is.null(new_data)) { |
| 77 | + tmp <- |
| 78 | + purrr::map_dbl(time, ~ predict(object$fit, times = .x, type = "surv")) |
| 79 | + } else { |
| 80 | + tmp <- |
| 81 | + purrr::map_dbl(time, ~ predict(object$fit, newdata = new_data, times = .x, type = "surv")) |
| 82 | + } |
| 83 | + |
| 84 | + zero_prob <- purrr::map_lgl(tmp, ~ !is.na(.x) && .x == 0) |
| 85 | + if (any(zero_prob)) { |
| 86 | + # Don't want censoring probabilities of zero so add an epsilon |
| 87 | + # Either use 1/n or half of the minimum survival probability |
| 88 | + n <- max(object$fit$n.risk) |
| 89 | + half_min_surv_prob <- min(object$fit$surv[object$fit$surv > 0]) / 2 |
| 90 | + eps <- min(1 / n, half_min_surv_prob) |
| 91 | + tmp[zero_prob] <- eps |
| 92 | + } |
| 93 | + |
| 94 | + if (length(is_na) > 0) { |
| 95 | + res[-is_na] <- tmp |
| 96 | + } else { |
| 97 | + res <- tmp |
| 98 | + } |
| 99 | + |
| 100 | + if (!as_vector) { |
| 101 | + res <- tibble::tibble(.prob_censored = unname(res)) |
| 102 | + } |
| 103 | + res |
| 104 | +} |
| 105 | + |
| 106 | +# nocov end |
0 commit comments