Skip to content

Commit 2bb8b2d

Browse files
committed
polished/updated parametric survival models.
1 parent dfd8d59 commit 2bb8b2d

File tree

13 files changed

+477
-69
lines changed

13 files changed

+477
-69
lines changed

R/surv_reg.R

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,39 @@
2525
#' `strata` function cannot be used. To achieve the same effect,
2626
#' the extra parameter roles can be used (as described above).
2727
#'
28-
#' The model can be created using the `fit()` function using the
29-
#' following _engines_:
30-
#' \itemize{
31-
#' \item \pkg{R}: `"flexsurv"`
32-
#' }
3328
#' @inheritParams boost_tree
3429
#' @param mode A single character string for the type of model.
3530
#' The only possible value for this model is "regression".
3631
#' @param dist A character string for the outcome distribution. "weibull" is
3732
#' the default.
33+
#' @details
34+
#' For `surv_reg`, the mode will always be "regression".
35+
#'
36+
#' The model can be created using the `fit()` function using the
37+
#' following _engines_:
38+
#' \itemize{
39+
#' \item \pkg{R}: `"flexsurv"`, `"survreg"`
40+
#' }
41+
#'
42+
#' @section Engine Details:
43+
#'
44+
#' Engines may have pre-set default arguments when executing the
45+
#' model fit call. These can be changed by using the `...`
46+
#' argument to pass in the preferred values. For this type of
47+
#' model, the template of the fit calls are:
48+
#'
49+
#' \pkg{flexsurv}
50+
#'
51+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")}
52+
#'
53+
#' \pkg{survreg}
54+
#'
55+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")}
56+
#'
57+
#' Note that `model = TRUE` is needed to produce quantile
58+
#' predictions when there is a stratification variable and can be
59+
#' overridden in other cases.
60+
#'
3861
#' @seealso [varying()], [fit()], [survival::Surv()]
3962
#' @references Jackson, C. (2016). `flexsurv`: A Platform for Parametric Survival
4063
#' Modeling in R. _Journal of Statistical Software_, 70(8), 1 - 33.
@@ -160,3 +183,51 @@ check_args.surv_reg <- function(object) {
160183

161184
invisible(object)
162185
}
186+
187+
# ------------------------------------------------------------------------------
188+
189+
#' @importFrom stats setNames
190+
#' @importFrom dplyr mutate
191+
survreg_quant <- function(results, object) {
192+
pctl <- object$spec$method$quantile$args$p
193+
n <- nrow(results)
194+
p <- ncol(results)
195+
results <-
196+
results %>%
197+
as_tibble() %>%
198+
setNames(names0(p)) %>%
199+
mutate(.row = 1:n) %>%
200+
gather(.label, .pred, -.row) %>%
201+
arrange(.row, .label) %>%
202+
mutate(.quantile = rep(pctl, n)) %>%
203+
dplyr::select(-.label)
204+
.row <- results[[".row"]]
205+
results <-
206+
results %>%
207+
dplyr::select(-.row)
208+
results <- split(results, .row)
209+
names(results) <- NULL
210+
tibble(.pred = results)
211+
}
212+
213+
# ------------------------------------------------------------------------------
214+
215+
#' @importFrom dplyr bind_rows
216+
flexsurv_mean <- function(results, object) {
217+
results <- unclass(results)
218+
results <- bind_rows(results)
219+
results$est
220+
}
221+
222+
#' @importFrom stats setNames
223+
flexsurv_quant <- function(results, object) {
224+
results <- map(results, as_tibble)
225+
names(results) <- NULL
226+
results <- map(results, setNames, c(".quantile", ".pred", ".pred_lower", ".pred_upper"))
227+
}
228+
229+
# ------------------------------------------------------------------------------
230+
231+
#' @importFrom utils globalVariables
232+
utils::globalVariables(".label")
233+

R/surv_reg_data.R

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11

22
surv_reg_arg_key <- data.frame(
3-
flexsurv = c("dist"),
3+
flexsurv = c("dist"),
4+
survreg = c("dist"),
5+
stan = c("family"),
46
stringsAsFactors = FALSE,
57
row.names = c("dist")
68
)
@@ -9,6 +11,8 @@ surv_reg_modes <- "regression"
911

1012
surv_reg_engines <- data.frame(
1113
flexsurv = TRUE,
14+
survreg = TRUE,
15+
stan = TRUE,
1216
stringsAsFactors = TRUE,
1317
row.names = c("regression")
1418
)
@@ -23,5 +27,96 @@ surv_reg_flexsurv_data <-
2327
protect = c("formula", "data", "weights"),
2428
func = c(pkg = "flexsurv", fun = "flexsurvreg"),
2529
defaults = list()
30+
),
31+
pred = list(
32+
pre = NULL,
33+
post = flexsurv_mean,
34+
func = c(fun = "summary"),
35+
args =
36+
list(
37+
object = expr(object$fit),
38+
newdata = expr(new_data),
39+
type = "mean"
40+
)
41+
),
42+
quantile = list(
43+
pre = NULL,
44+
post = flexsurv_quant,
45+
func = c(fun = "summary"),
46+
args =
47+
list(
48+
object = expr(object$fit),
49+
newdata = expr(new_data),
50+
type = "quantile",
51+
quantiles = expr(quantile)
52+
)
2653
)
2754
)
55+
56+
# ------------------------------------------------------------------------------
57+
58+
surv_reg_survreg_data <-
59+
list(
60+
libs = c("survival"),
61+
fit = list(
62+
interface = "formula",
63+
protect = c("formula", "data", "weights"),
64+
func = c(pkg = "survival", fun = "survreg"),
65+
defaults = list(model = TRUE)
66+
),
67+
pred = list(
68+
pre = NULL,
69+
post = NULL,
70+
func = c(fun = "predict"),
71+
args =
72+
list(
73+
object = expr(object$fit),
74+
newdata = expr(new_data),
75+
type = "response"
76+
)
77+
),
78+
quantile = list(
79+
pre = NULL,
80+
post = survreg_quant,
81+
func = c(fun = "predict"),
82+
args =
83+
list(
84+
object = expr(object$fit),
85+
newdata = expr(new_data),
86+
type = "quantile",
87+
p = expr(quantile)
88+
)
89+
)
90+
)
91+
92+
# ------------------------------------------------------------------------------
93+
94+
surv_reg_stan_data <-
95+
list(
96+
libs = c("brms"),
97+
fit = list(
98+
interface = "formula",
99+
protect = c("formula", "data", "weights"),
100+
func = c(pkg = "brms", fun = "brm"),
101+
defaults = list(
102+
family = expr(brms::weibull()),
103+
seed = expr(sample.int(10^5, 1))
104+
)
105+
),
106+
pred = list(
107+
pre = NULL,
108+
post = function(results, object) {
109+
tibble::as_tibble(results) %>%
110+
dplyr::select(Estimate) %>%
111+
setNames(".pred")
112+
},
113+
func = c(fun = "predict"),
114+
args =
115+
list(
116+
object = expr(object$fit),
117+
newdata = expr(new_data),
118+
type = "response"
119+
)
120+
)
121+
)
122+

docs/articles/articles/Classification.html

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)