Skip to content

Commit 4081293

Browse files
authored
Merge pull request #97 from topepo/quosure-passthrough-tests
Misc updates
2 parents a497d07 + 0f406df commit 4081293

26 files changed

+619
-97
lines changed

.travis.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,15 @@ r:
1515
- devel
1616

1717
env:
18-
- KERAS_BACKEND="tensorflow"
1918
global:
20-
- MAKEFLAGS="-j 2"
19+
- KERAS_BACKEND="tensorflow"
20+
- MAKEFLAGS="-j 2"
21+
22+
# until we troubleshoot these issues
23+
matrix:
24+
allow_failures:
25+
- r: 3.1
26+
- r: 3.2
2127

2228
r_binary_packages:
2329
- rstan

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ S3method(predict_confint,model_fit)
2222
S3method(predict_num,"_elnet")
2323
S3method(predict_num,model_fit)
2424
S3method(predict_predint,model_fit)
25+
S3method(predict_quantile,model_fit)
2526
S3method(predict_raw,"_elnet")
2627
S3method(predict_raw,"_lognet")
2728
S3method(predict_raw,"_multnet")
@@ -95,6 +96,8 @@ export(predict_num)
9596
export(predict_num.model_fit)
9697
export(predict_predint)
9798
export(predict_predint.model_fit)
99+
export(predict_quantile)
100+
export(predict_quantile.model_fit)
98101
export(predict_raw)
99102
export(predict_raw.model_fit)
100103
export(rand_forest)
@@ -113,10 +116,12 @@ import(rlang)
113116
importFrom(dplyr,arrange)
114117
importFrom(dplyr,as_tibble)
115118
importFrom(dplyr,bind_cols)
119+
importFrom(dplyr,bind_rows)
116120
importFrom(dplyr,collect)
117121
importFrom(dplyr,full_join)
118122
importFrom(dplyr,funs)
119123
importFrom(dplyr,group_by)
124+
importFrom(dplyr,mutate)
120125
importFrom(dplyr,pull)
121126
importFrom(dplyr,rename)
122127
importFrom(dplyr,rename_at)
@@ -159,6 +164,7 @@ importFrom(stats,predict)
159164
importFrom(stats,qnorm)
160165
importFrom(stats,qt)
161166
importFrom(stats,quantile)
167+
importFrom(stats,setNames)
162168
importFrom(stats,terms)
163169
importFrom(stats,update)
164170
importFrom(tibble,as_tibble)

R/aaa_spark_helpers.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
#' @importFrom dplyr starts_with rename rename_at vars funs
44
format_spark_probs <- function(results, object) {
55
results <- dplyr::select(results, starts_with("probability_"))
6-
results <- dplyr::rename_at(
7-
results,
8-
vars(starts_with("probability_")),
9-
funs(gsub("probability", "pred", .))
10-
)
11-
results
6+
p <- ncol(results)
7+
lvl <- paste0("probability_", 0:(p - 1))
8+
names(lvl) <- paste0("pred_", object$fit$.index_labels)
9+
results %>% rename(!!!syms(lvl))
1210
}
1311

1412
format_spark_class <- function(results, object) {

R/fit_helpers.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ form_form <-
88
function(object, control, env, ...) {
99
opts <- quos(...)
1010

11-
y_levels <- levels_from_formula( # prob rewrite this as simple subset/levels
12-
env$formula,
13-
env$data
14-
)
11+
if (object$mode != "regression") {
12+
y_levels <- levels_from_formula( # prob rewrite this as simple subset/levels
13+
env$formula,
14+
env$data
15+
)
16+
} else {
17+
y_levels <- NULL
18+
}
1519

1620
object <- check_mode(object, y_levels)
1721

R/misc.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,15 @@ check_args <- function(object) {
178178
check_args.default <- function(object) {
179179
invisible(object)
180180
}
181+
182+
# ------------------------------------------------------------------------------
183+
184+
# copied form recipes
185+
186+
names0 <- function (num, prefix = "x") {
187+
if (num < 1)
188+
stop("`num` should be > 0", call. = FALSE)
189+
ind <- format(1:num)
190+
ind <- gsub(" ", "0", ind)
191+
paste0(prefix, ind)
192+
}

R/predict.R

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
#' @param object An object of class `model_fit`
88
#' @param new_data A rectangular data object, such as a data frame.
99
#' @param type A single character value or `NULL`. Possible values
10-
#' are "numeric", "class", "probs", "conf_int", "pred_int", or
11-
#' "raw". When `NULL`, `predict` will choose an appropriate value
10+
#' are "numeric", "class", "probs", "conf_int", "pred_int", "quantile",
11+
#' or "raw". When `NULL`, `predict` will choose an appropriate value
1212
#' based on the model's mode.
1313
#' @param opts A list of optional arguments to the underlying
1414
#' predict function that will be used when `type = "raw"`. The
@@ -45,6 +45,10 @@
4545
#' produces for class probabilities (or other non-scalar outputs),
4646
#' the columns will be named `.pred_lower_classlevel` and so on.
4747
#'
48+
#' Quantile predictions return a tibble with a column `.pred`, which is
49+
#' a list-column. Each list element contains a tibble with columns
50+
#' `.pred` and `.quantile` (and perhaps others).
51+
#'
4852
#' Using `type = "raw"` with `predict.model_fit` (or using
4953
#' `predict_raw`) will return the unadulterated results of the
5054
#' prediction function.
@@ -96,6 +100,7 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...
96100
prob = predict_classprob(object = object, new_data = new_data, ...),
97101
conf_int = predict_confint(object = object, new_data = new_data, ...),
98102
pred_int = predict_predint(object = object, new_data = new_data, ...),
103+
quantile = predict_quantile(object = object, new_data = new_data, ...),
99104
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
100105
stop("I don't know about type = '", "'", type, call. = FALSE)
101106
)
@@ -112,7 +117,8 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...
112117
res
113118
}
114119

115-
pred_types <- c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int")
120+
pred_types <-
121+
c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile")
116122

117123
#' @importFrom glue glue_collapse
118124
check_pred_type <- function(object, type) {

R/predict_quantile.R

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#' @keywords internal
2+
#' @rdname other_predict
3+
#' @param quant A vector of numbers between 0 and 1 for the quantile being
4+
#' predicted.
5+
#' @inheritParams predict.model_fit
6+
#' @method predict_quantile model_fit
7+
#' @export predict_quantile.model_fit
8+
#' @export
9+
predict_quantile.model_fit <-
10+
function (object, new_data, quantile = (1:9)/10, ...) {
11+
12+
if (is.null(object$spec$method$quantile))
13+
stop("No quantile prediction method defined for this ",
14+
"engine.", call. = FALSE)
15+
16+
new_data <- prepare_data(object, new_data)
17+
18+
# preprocess data
19+
if (!is.null(object$spec$method$quantile$pre))
20+
new_data <- object$spec$method$quantile$pre(new_data, object)
21+
22+
# Pass some extra arguments to be used in post-processor
23+
object$spec$method$quantile$args$p <- quantile
24+
pred_call <- make_pred_call(object$spec$method$quantile)
25+
26+
res <- eval_tidy(pred_call)
27+
28+
# post-process the predictions
29+
if(!is.null(object$spec$method$quantile$post)) {
30+
res <- object$spec$method$quantile$post(res, object)
31+
}
32+
33+
res
34+
}
35+
36+
#' @export
37+
#' @keywords internal
38+
#' @rdname other_predict
39+
#' @inheritParams predict.model_fit
40+
predict_quantile <- function (object, ...)
41+
UseMethod("predict_quantile")

R/surv_reg.R

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,39 @@
2525
#' `strata` function cannot be used. To achieve the same effect,
2626
#' the extra parameter roles can be used (as described above).
2727
#'
28-
#' The model can be created using the `fit()` function using the
29-
#' following _engines_:
30-
#' \itemize{
31-
#' \item \pkg{R}: `"flexsurv"`
32-
#' }
3328
#' @inheritParams boost_tree
3429
#' @param mode A single character string for the type of model.
3530
#' The only possible value for this model is "regression".
3631
#' @param dist A character string for the outcome distribution. "weibull" is
3732
#' the default.
33+
#' @details
34+
#' For `surv_reg`, the mode will always be "regression".
35+
#'
36+
#' The model can be created using the `fit()` function using the
37+
#' following _engines_:
38+
#' \itemize{
39+
#' \item \pkg{R}: `"flexsurv"`, `"survreg"`
40+
#' }
41+
#'
42+
#' @section Engine Details:
43+
#'
44+
#' Engines may have pre-set default arguments when executing the
45+
#' model fit call. These can be changed by using the `...`
46+
#' argument to pass in the preferred values. For this type of
47+
#' model, the template of the fit calls are:
48+
#'
49+
#' \pkg{flexsurv}
50+
#'
51+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")}
52+
#'
53+
#' \pkg{survreg}
54+
#'
55+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")}
56+
#'
57+
#' Note that `model = TRUE` is needed to produce quantile
58+
#' predictions when there is a stratification variable and can be
59+
#' overridden in other cases.
60+
#'
3861
#' @seealso [varying()], [fit()], [survival::Surv()]
3962
#' @references Jackson, C. (2016). `flexsurv`: A Platform for Parametric Survival
4063
#' Modeling in R. _Journal of Statistical Software_, 70(8), 1 - 33.
@@ -160,3 +183,51 @@ check_args.surv_reg <- function(object) {
160183

161184
invisible(object)
162185
}
186+
187+
# ------------------------------------------------------------------------------
188+
189+
#' @importFrom stats setNames
190+
#' @importFrom dplyr mutate
191+
survreg_quant <- function(results, object) {
192+
pctl <- object$spec$method$quantile$args$p
193+
n <- nrow(results)
194+
p <- ncol(results)
195+
results <-
196+
results %>%
197+
as_tibble() %>%
198+
setNames(names0(p)) %>%
199+
mutate(.row = 1:n) %>%
200+
gather(.label, .pred, -.row) %>%
201+
arrange(.row, .label) %>%
202+
mutate(.quantile = rep(pctl, n)) %>%
203+
dplyr::select(-.label)
204+
.row <- results[[".row"]]
205+
results <-
206+
results %>%
207+
dplyr::select(-.row)
208+
results <- split(results, .row)
209+
names(results) <- NULL
210+
tibble(.pred = results)
211+
}
212+
213+
# ------------------------------------------------------------------------------
214+
215+
#' @importFrom dplyr bind_rows
216+
flexsurv_mean <- function(results, object) {
217+
results <- unclass(results)
218+
results <- bind_rows(results)
219+
results$est
220+
}
221+
222+
#' @importFrom stats setNames
223+
flexsurv_quant <- function(results, object) {
224+
results <- map(results, as_tibble)
225+
names(results) <- NULL
226+
results <- map(results, setNames, c(".quantile", ".pred", ".pred_lower", ".pred_upper"))
227+
}
228+
229+
# ------------------------------------------------------------------------------
230+
231+
#' @importFrom utils globalVariables
232+
utils::globalVariables(".label")
233+

0 commit comments

Comments
 (0)