Skip to content

Commit c6ec573

Browse files
committed
knn multi_predict
1 parent d0bc917 commit c6ec573

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

R/nearest_neighbor.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,43 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
178178
}
179179
x
180180
}
181+
182+
183+
# ------------------------------------------------------------------------------
184+
185+
#' @importFrom purrr map_df
186+
#' @importFrom dplyr starts_with
187+
#' @rdname multi_predict
188+
#' @param neighbors An integer vector for the number of nearest neighbors.
189+
#' @export
190+
multi_predict._train.kknn <-
191+
function(object, new_data, type = NULL, neighbors = NULL, ...) {
192+
if (any(names(enquos(...)) == "newdata"))
193+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
194+
195+
if (is.null(neighbors))
196+
neighbors <- rlang::eval_tidy(tt$fit$call$ks)
197+
neighbors <- sort(neighbors)
198+
199+
if (is.null(type)) {
200+
if (object$spec$mode == "classification")
201+
type <- "class"
202+
else
203+
type <- "numeric"
204+
}
205+
206+
res <-
207+
purrr::map_df(neighbors, knn_by_k, object = object,
208+
new_data = new_data, type = type, ...)
209+
res <- dplyr::arrange(res, .row, neighbors)
210+
res <- split(res[, -1], res$.row)
211+
names(res) <- NULL
212+
dplyr::tibble(.pred = res)
213+
}
214+
215+
knn_by_k <- function(k, object, new_data, type, ...) {
216+
object$fit$call$ks <- k
217+
predict(object, new_data = new_data, type = type, ...) %>%
218+
dplyr::mutate(neighbors = k, .row = dplyr::row_number()) %>%
219+
dplyr::select(.row, neighbors, dplyr::starts_with(".pred"))
220+
}

0 commit comments

Comments
 (0)