Skip to content

Commit b647fc7

Browse files
authored
Merge pull request #183 from tidymodels/knn-multi-predict
KNN multi predict
2 parents 7829acb + e6d93c7 commit b647fc7

File tree

13 files changed

+209
-3
lines changed

13 files changed

+209
-3
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ S3method(multi_predict,"_earth")
1010
S3method(multi_predict,"_elnet")
1111
S3method(multi_predict,"_lognet")
1212
S3method(multi_predict,"_multnet")
13+
S3method(multi_predict,"_train.kknn")
1314
S3method(multi_predict,"_xgb.Booster")
1415
S3method(multi_predict,default)
1516
S3method(nullmodel,default)

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
* `add_rowindex()` can add a column called `.row` to a data frame.
1212

1313
* If a computational engine is not explicitly set, a default will be used. Each default is documented on the corresponding model page. A warning is issued at fit time unless verbosity is zero.
14+
* `nearest_neighbor` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized.
1415

1516

1617
# parsnip 0.0.2

R/aaa.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
2323
#' @importFrom utils globalVariables
2424
utils::globalVariables(
2525
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
26-
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type')
26+
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
27+
"neighbors")
2728
)

R/aaa_multi_predict.R

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Define a generic to make multiple predictions for the same model object ------
2+
3+
#' Model predictions across many sub-models
4+
#'
5+
#' For some models, predictions can be made on sub-models in the model object.
6+
#' @param object A `model_fit` object.
7+
#' @param new_data A rectangular data object, such as a data frame.
8+
#' @param type A single character value or `NULL`. Possible values
9+
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
10+
#' or "raw". When `NULL`, `predict()` will choose an appropriate value
11+
#' based on the model's mode.
12+
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
13+
#' such as `type`.
14+
#' @return A tibble with the same number of rows as the data being predicted.
15+
#' Most likely, there is a list-column named `.pred` that is a tibble with
16+
#' multiple rows per sub-model.
17+
#' @export
18+
multi_predict <- function(object, ...) {
  # Dispatch only when the stored fit is usable. `UseMethod()` hands control
  # to the selected method, so nothing below it runs once dispatch happens.
  if (!inherits(object$fit, "try-error")) {
    UseMethod("multi_predict")
  }

  # The fit step failed (the fit slot holds a `try-error`): warn, return NULL.
  warning("Model fit failed; cannot make predictions.", call. = FALSE)
  NULL
}
25+
26+
#' @export
27+
#' @rdname multi_predict
28+
multi_predict.default <- function(object, ...) {
  # Fallback method: reached when no class-specific `multi_predict` method
  # matches `object`. Always raises an error listing the object's classes.
  #
  # `class()` must be given the object; calling it with no argument fails
  # with an unrelated arity error and hides the intended message.
  stop("No `multi_predict` method exists for objects with classes ",
       paste0("'", class(object), "'", collapse = ", "), call. = FALSE)
}
31+
32+
#' @export
33+
predict.model_spec <- function(object, ...) {
  # This method always errors, telling the caller to `fit()` the
  # specification before calling `predict()`.
  msg <- "You must use `fit()` on your model specification before you can use `predict()`."
  stop(msg, call. = FALSE)
}

R/boost_tree.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,8 @@ xgb_pred <- function(object, newdata, ...) {
366366

367367
#' @importFrom purrr map_df
368368
#' @export
369+
#' @rdname multi_predict
370+
#' @param trees An integer vector for the number of trees in the ensemble.
369371
multi_predict._xgb.Booster <-
370372
function(object, new_data, type = NULL, trees = NULL, ...) {
371373
if (any(names(enquos(...)) == "newdata")) {
@@ -474,6 +476,7 @@ C5.0_train <-
474476
}
475477

476478
#' @export
479+
#' @rdname multi_predict
477480
multi_predict._C5.0 <-
478481
function(object, new_data, type = NULL, trees = NULL, ...) {
479482
if (any(names(enquos(...)) == "newdata"))

R/linear_reg.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
302302
#' @importFrom dplyr full_join as_tibble arrange
303303
#' @importFrom tidyr gather
304304
#' @export
305+
#' @rdname multi_predict
306+
#' @param penalty A numeric vector of penalty values.
305307
multi_predict._elnet <-
306308
function(object, new_data, type = NULL, penalty = NULL, ...) {
307309
if (any(names(enquos(...)) == "newdata"))

R/logistic_reg.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ predict._lognet <- function (object, new_data, type = NULL, opts = list(), penal
276276
#' @importFrom dplyr full_join as_tibble arrange
277277
#' @importFrom tidyr gather
278278
#' @export
279+
#' @rdname multi_predict
279280
multi_predict._lognet <-
280281
function(object, new_data, type = NULL, penalty = NULL, ...) {
281282
if (any(names(enquos(...)) == "newdata"))

R/mars.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ earth_reg_updater <- function(num, object, new_data, ...) {
205205

206206
#' @importFrom purrr map_df
207207
#' @importFrom dplyr arrange
208+
#' @rdname multi_predict
209+
#' @param num_terms An integer vector for the number of MARS terms to retain.
208210
#' @export
209211
multi_predict._earth <-
210212
function(object, new_data, type = NULL, num_terms = NULL, ...) {

R/multinom_reg.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ predict._multnet <-
232232
#' @importFrom dplyr full_join as_tibble arrange
233233
#' @importFrom tidyr gather
234234
#' @export
235+
#' @rdname multi_predict
235236
multi_predict._multnet <-
236237
function(object, new_data, type = NULL, penalty = NULL, ...) {
237238
if (any(names(enquos(...)) == "newdata"))

R/nearest_neighbor.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,43 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
178178
}
179179
x
180180
}
181+
182+
183+
# ------------------------------------------------------------------------------
184+
185+
#' @importFrom purrr map_df
186+
#' @importFrom dplyr starts_with
187+
#' @rdname multi_predict
188+
#' @param neighbors An integer vector for the number of nearest neighbors.
189+
#' @export
190+
multi_predict._train.kknn <-
  function(object, new_data, type = NULL, neighbors = NULL, ...) {
    # Catch the common misspelling of `new_data` passed through the dots.
    dot_names <- names(enquos(...))
    if (any(dot_names == "newdata")) {
      stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
    }

    # With no `neighbors` supplied, use the `ks` value stored in the fit's call.
    if (is.null(neighbors)) {
      neighbors <- rlang::eval_tidy(object$fit$call$ks)
    }
    neighbors <- sort(neighbors)

    # Pick the default prediction type from the model's mode.
    if (is.null(type)) {
      type <- if (object$spec$mode == "classification") "class" else "numeric"
    }

    # One set of predictions per value of `neighbors`, stacked row-wise.
    per_k <- purrr::map_df(
      neighbors, knn_by_k,
      object = object, new_data = new_data, type = type, ...
    )
    per_k <- dplyr::arrange(per_k, .row, neighbors)

    # Drop the first column and split by `.row`, so each list element holds
    # the predictions sharing one `.row` value.
    by_row <- split(per_k[, -1], per_k$.row)
    dplyr::tibble(.pred = unname(by_row))
  }
214+
215+
knn_by_k <- function(k, object, new_data, type, ...) {
  # Overwrite `ks` in the stored call so `predict()` runs with `k`.
  object$fit$call$ks <- k

  preds <- predict(object, new_data = new_data, type = type, ...)

  # Tag each prediction with the `k` used and its row position, then keep
  # only the id columns and the `.pred*` columns, in that order.
  tagged <- dplyr::mutate(preds, neighbors = k, .row = dplyr::row_number())
  dplyr::select(tagged, .row, neighbors, dplyr::starts_with(".pred"))
}

0 commit comments

Comments
 (0)