Skip to content

Commit 1c3bca7

Browse files
authored
Merge pull request #192 from tidymodels/more-multi-predict
added a function to get multi_predict arg names
2 parents 2fd29bf + 2f58384 commit 1c3bca7

22 files changed

+232
-123
lines changed

NAMESPACE

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
S3method(fit,model_spec)
44
S3method(fit_xy,model_spec)
5-
S3method(has_multi_pred,default)
6-
S3method(has_multi_pred,model_fit)
7-
S3method(has_multi_pred,workflow)
5+
S3method(has_multi_predict,default)
6+
S3method(has_multi_predict,model_fit)
7+
S3method(has_multi_predict,workflow)
88
S3method(multi_predict,"_C5.0")
99
S3method(multi_predict,"_earth")
1010
S3method(multi_predict,"_elnet")
@@ -13,6 +13,9 @@ S3method(multi_predict,"_multnet")
1313
S3method(multi_predict,"_train.kknn")
1414
S3method(multi_predict,"_xgb.Booster")
1515
S3method(multi_predict,default)
16+
S3method(multi_predict_args,default)
17+
S3method(multi_predict_args,model_fit)
18+
S3method(multi_predict_args,workflow)
1619
S3method(nullmodel,default)
1720
S3method(predict,"_elnet")
1821
S3method(predict,"_lognet")
@@ -95,7 +98,7 @@ export(get_fit)
9598
export(get_from_env)
9699
export(get_model_env)
97100
export(get_pred_type)
98-
export(has_multi_pred)
101+
export(has_multi_predict)
99102
export(keras_mlp)
100103
export(linear_reg)
101104
export(logistic_reg)
@@ -104,6 +107,7 @@ export(mars)
104107
export(mlp)
105108
export(model_printer)
106109
export(multi_predict)
110+
export(multi_predict_args)
107111
export(multinom_reg)
108112
export(nearest_neighbor)
109113
export(null_model)

R/aaa_multi_predict.R

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,102 @@ multi_predict.default <- function(object, ...)
3333
predict.model_spec <- function(object, ...) {
3434
stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE)
3535
}
36+
37+
#' Tools for models that predict on sub-models
38+
#'
39+
#' `has_multi_predict()` tests to see if an object can make multiple
40+
#' predictions on submodels from the same object. `multi_predict_args()`
41+
#' returns the names of the argments to `multi_predict()` for this model
42+
#' (if any).
43+
#' @param object An object to test.
44+
#' @param ... Not currently used.
45+
#' @return `has_multi_predict()` returns single logical value while
46+
#' `multi_predict()` returns a character vector of argument names (or `NA`
47+
#' if none exist).
48+
#' @keywords internal
49+
#' @examples
50+
#' lm_model_idea <- linear_reg() %>% set_engine("lm")
51+
#' has_multi_predict(lm_model_idea)
52+
#' lm_model_fit <- fit(lm_model_idea, mpg ~ ., data = mtcars)
53+
#' has_multi_predict(lm_model_fit)
54+
#'
55+
#' multi_predict_args(lm_model_fit)
56+
#'
57+
#' library(kknn)
58+
#'
59+
#' knn_fit <-
60+
#' nearest_neighbor(mode = "regression", neighbors = 5) %>%
61+
#' set_engine("kknn") %>%
62+
#' fit(mpg ~ ., mtcars)
63+
#'
64+
#' multi_predict_args(knn_fit)
65+
#'
66+
#' multi_predict(knn_fit, mtcars[1, -1], neighbors = 1:4)$.pred
67+
#' @importFrom utils methods
68+
#' @export
69+
has_multi_predict <- function(object, ...) {
70+
UseMethod("has_multi_predict")
71+
}
72+
73+
#' @export
74+
#' @rdname has_multi_predict
75+
has_multi_predict.default <- function(object, ...) {
76+
FALSE
77+
}
78+
79+
#' @export
80+
#' @rdname has_multi_predict
81+
has_multi_predict.model_fit <- function(object, ...) {
82+
existing_mthds <- utils::methods("multi_predict")
83+
tst <- paste0("multi_predict.", class(object))
84+
any(tst %in% existing_mthds)
85+
}
86+
87+
#' @export
88+
#' @rdname has_multi_predict
89+
has_multi_predict.workflow <- function(object, ...) {
90+
has_multi_predict(object$fit$model$model)
91+
}
92+
93+
94+
#' @rdname has_multi_predict
95+
#' @export
96+
#' @rdname has_multi_predict
97+
multi_predict_args <- function(object, ...) {
98+
UseMethod("multi_predict_args")
99+
}
100+
101+
#' @export
102+
#' @rdname has_multi_predict
103+
multi_predict_args.default <- function(object, ...) {
104+
if (inherits(object, "model_fit")) {
105+
res <- multi_predict_args.model_fit(object, ...)
106+
} else {
107+
res <- NA_character_
108+
}
109+
res
110+
}
111+
112+
#' @export
113+
#' @rdname has_multi_predict
114+
multi_predict_args.model_fit <- function(object, ...) {
115+
existing_mthds <- methods("multi_predict")
116+
cls <- class(object)
117+
tst <- paste0("multi_predict.", cls)
118+
.fn <- tst[tst %in% existing_mthds]
119+
if (length(.fn) == 0) {
120+
return(NA_character_)
121+
}
122+
123+
.fn <- getFromNamespace(.fn, ns = "parsnip")
124+
omit <- c('object', 'new_data', 'type', '...')
125+
args <- names(formals(.fn))
126+
args[!(args %in% omit)]
127+
}
128+
129+
#' @export
130+
#' @rdname has_multi_predict
131+
multi_predict_args.workflow <- function(object, ...) {
132+
object <- object$fit$model$model
133+
134+
}

R/nearest_neighbor_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ set_model_arg(
1616
parsnip = "neighbors",
1717
original = "ks",
1818
func = list(pkg = "dials", fun = "neighbors"),
19-
has_submodel = FALSE
19+
has_submodel = TRUE
2020
)
2121
set_model_arg(
2222
model = "nearest_neighbor",

R/predict.R

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -231,73 +231,3 @@ prepare_data <- function(object, new_data) {
231231
new_data
232232
}
233233

234-
# Define a generic to make multiple predictions for the same model object ------
235-
236-
#' Model predictions across many sub-models
237-
#'
238-
#' For some models, predictions can be made on sub-models in the model object.
239-
#' @param object A `model_fit` object.
240-
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
241-
#' such as `type`.
242-
#' @return A tibble with the same number of rows as the data being predicted.
243-
#' Mostly likely, there is a list-column named `.pred` that is a tibble with
244-
#' multiple rows per sub-model.
245-
#' @export
246-
multi_predict <- function(object, ...) {
247-
if (inherits(object$fit, "try-error")) {
248-
warning("Model fit failed; cannot make predictions.", call. = FALSE)
249-
return(NULL)
250-
}
251-
UseMethod("multi_predict")
252-
}
253-
254-
#' @export
255-
#' @rdname multi_predict
256-
multi_predict.default <- function(object, ...)
257-
stop("No `multi_predict` method exists for objects with classes ",
258-
paste0("'", class(), "'", collapse = ", "), call. = FALSE)
259-
260-
#' @export
261-
predict.model_spec <- function(object, ...) {
262-
stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE)
263-
}
264-
265-
266-
267-
#' Determine if a model can make predictions on sub-models
268-
#'
269-
#' @param object An object to test.
270-
#' @param ... Not currently used.
271-
#' @return A single logical value.
272-
#' @keywords internal
273-
#' @examples
274-
#' model_idea <- linear_reg() %>% set_engine("lm")
275-
#' has_multi_pred(model_idea)
276-
#' model_fit <- fit(model_idea, mpg ~ ., data = mtcars)
277-
#' has_multi_pred(model_fit)
278-
#' @importFrom utils methods
279-
#' @export
280-
has_multi_pred <- function(object, ...) {
281-
UseMethod("has_multi_pred")
282-
}
283-
284-
#' @export
285-
#' @rdname has_multi_pred
286-
has_multi_pred.default <- function(object, ...) {
287-
FALSE
288-
}
289-
290-
#' @export
291-
#' @rdname has_multi_pred
292-
has_multi_pred.model_fit <- function(object, ...) {
293-
existing_mthds <- utils::methods("multi_predict")
294-
tst <- paste0("multi_predict.", class(object))
295-
any(tst %in% existing_mthds)
296-
}
297-
298-
#' @export
299-
#' @rdname has_multi_pred
300-
has_multi_pred.workflow <- function(object, ...) {
301-
has_multi_pred(object$fit$model$model)
302-
}
303-

man/has_multi_pred.Rd

Lines changed: 0 additions & 35 deletions
This file was deleted.

man/has_multi_predict.Rd

Lines changed: 65 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree_C50.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ test_that('C5.0 execution', {
4646
regexp = NA
4747
)
4848

49+
expect_true(has_multi_predict(res))
50+
expect_equal(multi_predict_args(res), "trees")
51+
4952
# outcome is not a factor:
5053
expect_error(
5154
res <- fit(

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ test_that('xgboost execution, classification', {
4040
regexp = NA
4141
)
4242

43+
expect_true(has_multi_predict(res))
44+
expect_equal(multi_predict_args(res), "trees")
45+
4346
expect_error(
4447
res <- parsnip::fit(
4548
iris_xgboost,

tests/testthat/test_linear_reg_glmnet.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ test_that('glmnet execution', {
2525
skip_if_not_installed("glmnet")
2626

2727
expect_error(
28-
fit_xy(
28+
res <- fit_xy(
2929
iris_basic,
3030
control = ctrl,
3131
x = iris[, num_pred],
@@ -34,6 +34,9 @@ test_that('glmnet execution', {
3434
regexp = NA
3535
)
3636

37+
expect_true(has_multi_predict(res))
38+
expect_equal(multi_predict_args(res), "penalty")
39+
3740
expect_error(
3841
fit(
3942
iris_basic,

tests/testthat/test_linear_reg_spark.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ test_that('spark execution', {
3636
regexp = NA
3737
)
3838

39+
expect_false(has_multi_predict(spark_fit))
40+
expect_equal(multi_predict_args(spark_fit), NA_character_)
41+
3942
expect_error(
4043
spark_pred <- predict(spark_fit, iris_linreg_te),
4144
regexp = NA

0 commit comments

Comments
 (0)