Skip to content

Commit 3ebfbff

Browse files
committed
fixed multi-predict for knn
1 parent 174e311 commit 3ebfbff

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

R/nearest_neighbor.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ multi_predict._train.kknn <-
213213
}
214214

215215
knn_by_k <- function(k, object, new_data, type, ...) {
216-
object$fit$call$ks <- k
216+
object$fit$best.parameters$k <- k
217+
217218
predict(object, new_data = new_data, type = type, ...) %>%
218219
dplyr::mutate(neighbors = k, .row = dplyr::row_number()) %>%
219220
dplyr::select(.row, neighbors, dplyr::starts_with(".pred"))
@@ -244,8 +245,9 @@ min_grid.nearest_neighbor <- function(x, grid, ...) {
244245
min_grid_df <-
245246
dplyr::full_join(fit_only %>% rename(max_neighbor = neighbors), grid, by = fixed_args) %>%
246247
dplyr::filter(neighbors != max_neighbor) %>%
248+
dplyr::rename(sub_neighbors = neighbors, neighbors = max_neighbor) %>%
247249
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
248-
dplyr::summarize(.submodels = list(list(neighbors = neighbors))) %>%
250+
dplyr::summarize(.submodels = list(list(neighbors = sub_neighbors))) %>%
249251
dplyr::ungroup() %>%
250252
dplyr::full_join(fit_only, grid, by = fixed_args)
251253

0 commit comments

Comments
 (0)