@@ -33,3 +33,102 @@ multi_predict.default <- function(object, ...)
3333predict.model_spec <- function (object , ... ) {
3434 stop(" You must use `fit()` on your model specification before you can use `predict()`." , call. = FALSE )
3535}
36+
37+ # ' Tools for models that predict on sub-models
38+ # '
39+ # ' `has_multi_predict()` tests to see if an object can make multiple
40+ # ' predictions on submodels from the same object. `multi_predict_args()`
41+ # ' returns the names of the argments to `multi_predict()` for this model
42+ # ' (if any).
43+ # ' @param object An object to test.
44+ # ' @param ... Not currently used.
45+ # ' @return `has_multi_predict()` returns single logical value while
46+ # ' `multi_predict()` returns a character vector of argument names (or `NA`
47+ # ' if none exist).
48+ # ' @keywords internal
49+ # ' @examples
50+ # ' lm_model_idea <- linear_reg() %>% set_engine("lm")
51+ # ' has_multi_predict(lm_model_idea)
52+ # ' lm_model_fit <- fit(lm_model_idea, mpg ~ ., data = mtcars)
53+ # ' has_multi_predict(lm_model_fit)
54+ # '
55+ # ' multi_predict_args(lm_model_fit)
56+ # '
57+ # ' library(kknn)
58+ # '
59+ # ' knn_fit <-
60+ # ' nearest_neighbor(mode = "regression", neighbors = 5) %>%
61+ # ' set_engine("kknn") %>%
62+ # ' fit(mpg ~ ., mtcars)
63+ # '
64+ # ' multi_predict_args(knn_fit)
65+ # '
66+ # ' multi_predict(knn_fit, mtcars[1, -1], neighbors = 1:4)$.pred
67+ # ' @importFrom utils methods
68+ # ' @export
69+ has_multi_predict <- function (object , ... ) {
70+ UseMethod(" has_multi_predict" )
71+ }
72+
73+ # ' @export
74+ # ' @rdname has_multi_predict
75+ has_multi_predict.default <- function (object , ... ) {
76+ FALSE
77+ }
78+
79+ # ' @export
80+ # ' @rdname has_multi_predict
81+ has_multi_predict.model_fit <- function (object , ... ) {
82+ existing_mthds <- utils :: methods(" multi_predict" )
83+ tst <- paste0(" multi_predict." , class(object ))
84+ any(tst %in% existing_mthds )
85+ }
86+
87+ # ' @export
88+ # ' @rdname has_multi_predict
89+ has_multi_predict.workflow <- function (object , ... ) {
90+ has_multi_predict(object $ fit $ model $ model )
91+ }
92+
93+
94+ # ' @rdname has_multi_predict
95+ # ' @export
96+ # ' @rdname has_multi_predict
97+ multi_predict_args <- function (object , ... ) {
98+ UseMethod(" multi_predict_args" )
99+ }
100+
101+ # ' @export
102+ # ' @rdname has_multi_predict
103+ multi_predict_args.default <- function (object , ... ) {
104+ if (inherits(object , " model_fit" )) {
105+ res <- multi_predict_args.model_fit(object , ... )
106+ } else {
107+ res <- NA_character_
108+ }
109+ res
110+ }
111+
112+ # ' @export
113+ # ' @rdname has_multi_predict
114+ multi_predict_args.model_fit <- function (object , ... ) {
115+ existing_mthds <- methods(" multi_predict" )
116+ cls <- class(object )
117+ tst <- paste0(" multi_predict." , cls )
118+ .fn <- tst [tst %in% existing_mthds ]
119+ if (length(.fn ) == 0 ) {
120+ return (NA_character_ )
121+ }
122+
123+ .fn <- getFromNamespace(.fn , ns = " parsnip" )
124+ omit <- c(' object' , ' new_data' , ' type' , ' ...' )
125+ args <- names(formals(.fn ))
126+ args [! (args %in% omit )]
127+ }
128+
129+ # ' @export
130+ # ' @rdname has_multi_predict
131+ multi_predict_args.workflow <- function (object , ... ) {
132+ object <- object $ fit $ model $ model
133+
134+ }
0 commit comments