Skip to content

Commit 7829acb

Browse files
authored
Merge pull request #189 from tidymodels/has-multi-predict
generic and methods for has_multi_predict
2 parents 90a2a3b + 35817ad commit 7829acb

File tree

4 files changed

+112
-0
lines changed

4 files changed

+112
-0
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
S3method(fit,model_spec)
44
S3method(fit_xy,model_spec)
5+
S3method(has_multi_pred,default)
6+
S3method(has_multi_pred,model_fit)
7+
S3method(has_multi_pred,workflow)
58
S3method(multi_predict,"_C5.0")
69
S3method(multi_predict,"_earth")
710
S3method(multi_predict,"_elnet")
@@ -91,6 +94,7 @@ export(get_fit)
9194
export(get_from_env)
9295
export(get_model_env)
9396
export(get_pred_type)
97+
export(has_multi_pred)
9498
export(keras_mlp)
9599
export(linear_reg)
96100
export(logistic_reg)
@@ -210,4 +214,5 @@ importFrom(utils,capture.output)
210214
importFrom(utils,getFromNamespace)
211215
importFrom(utils,globalVariables)
212216
importFrom(utils,head)
217+
importFrom(utils,methods)
213218
importFrom(vctrs,vec_unique)

R/predict.R

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,44 @@ multi_predict.default <- function(object, ...)
261261
predict.model_spec <- function(object, ...) {
262262
stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE)
263263
}
264+
265+
266+
267+
#' Determine if a model can make predictions on sub-models
268+
#'
269+
#' @param object An object to test.
270+
#' @param ... Not currently used.
271+
#' @return A single logical value.
272+
#' @keywords internal
273+
#' @examples
274+
#' model_idea <- linear_reg() %>% set_engine("lm")
275+
#' has_multi_pred(model_idea)
276+
#' model_fit <- fit(model_idea, mpg ~ ., data = mtcars)
277+
#' has_multi_pred(model_fit)
278+
#' @importFrom utils methods
279+
#' @export
280+
has_multi_pred <- function(object, ...) {
281+
UseMethod("has_multi_pred")
282+
}
283+
284+
#' @export
285+
#' @rdname has_multi_pred
286+
has_multi_pred.default <- function(object, ...) {
287+
FALSE
288+
}
289+
290+
#' @export
291+
#' @rdname has_multi_pred
292+
has_multi_pred.model_fit <- function(object, ...) {
293+
existing_mthds <- utils::methods("multi_predict")
294+
tst <- paste0("multi_predict.", class(object))
295+
any(tst %in% existing_mthds)
296+
}
297+
298+
#' @export
299+
#' @rdname has_multi_pred
300+
has_multi_pred.workflow <- function(object, ...) {
301+
has_multi_pred(object$fit$model$model)
302+
}
303+
304+

man/has_multi_pred.Rd

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_misc.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
# ------------------------------------------------------------------------------
3+
4+
context("checking for multi_predict")
5+
6+
test_that('parsnip objects', {
7+
8+
lm_idea <- linear_reg() %>% set_engine("lm")
9+
expect_false(has_multi_pred(lm_idea))
10+
11+
lm_fit <- fit(lm_idea, mpg ~ ., data = mtcars)
12+
expect_false(has_multi_pred(lm_fit))
13+
expect_false(has_multi_pred(lm_fit$fit))
14+
15+
mars_fit <-
16+
mars(mode = "regression") %>%
17+
set_engine("earth") %>%
18+
fit(mpg ~ ., data = mtcars)
19+
expect_true(has_multi_pred(mars_fit))
20+
expect_false(has_multi_pred(mars_fit$fit))
21+
})
22+
23+
test_that('other objects', {
24+
25+
expect_false(has_multi_pred(NULL))
26+
expect_false(has_multi_pred(NA))
27+
28+
})
29+
30+
# ------------------------------------------------------------------------------
31+

0 commit comments

Comments
 (0)