@@ -178,3 +178,43 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
178178 }
179179 x
180180}
181+
182+
183+ # ------------------------------------------------------------------------------
184+
185+ # ' @importFrom purrr map_df
186+ # ' @importFrom dplyr starts_with
187+ # ' @rdname multi_predict
188+ # ' @param neighbors An integer vector for the number of nearest neighbors.
189+ # ' @export
190+ multi_predict._train.kknn <-
191+ function (object , new_data , type = NULL , neighbors = NULL , ... ) {
192+ if (any(names(enquos(... )) == " newdata" ))
193+ stop(" Did you mean to use `new_data` instead of `newdata`?" , call. = FALSE )
194+
195+ if (is.null(neighbors ))
196+ neighbors <- rlang :: eval_tidy(tt $ fit $ call $ ks )
197+ neighbors <- sort(neighbors )
198+
199+ if (is.null(type )) {
200+ if (object $ spec $ mode == " classification" )
201+ type <- " class"
202+ else
203+ type <- " numeric"
204+ }
205+
206+ res <-
207+ purrr :: map_df(neighbors , knn_by_k , object = object ,
208+ new_data = new_data , type = type , ... )
209+ res <- dplyr :: arrange(res , .row , neighbors )
210+ res <- split(res [, - 1 ], res $ .row )
211+ names(res ) <- NULL
212+ dplyr :: tibble(.pred = res )
213+ }
214+
215+ knn_by_k <- function (k , object , new_data , type , ... ) {
216+ object $ fit $ call $ ks <- k
217+ predict(object , new_data = new_data , type = type , ... ) %> %
218+ dplyr :: mutate(neighbors = k , .row = dplyr :: row_number()) %> %
219+ dplyr :: select(.row , neighbors , dplyr :: starts_with(" .pred" ))
220+ }
0 commit comments