Skip to content

Commit ede3d8a

Browse files
committed
Updating for nearest_neighbor() engine specific args
1 parent 3a52078 commit ede3d8a

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

R/nearest_neighbor.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ update.nearest_neighbor <- function(object,
9696
weight_func = NULL,
9797
dist_power = NULL,
9898
fresh = FALSE, ...) {
99-
update_dot_check(...)
99+
100+
eng_args <- update_engine_parameters(object$eng_args, ...)
100101

101102
if (!is.null(parameters)) {
102103
parameters <- check_final_param(parameters)
@@ -112,12 +113,15 @@ update.nearest_neighbor <- function(object,
112113

113114
if (fresh) {
114115
object$args <- args
116+
object$eng_args <- eng_args
115117
} else {
116118
null_args <- map_lgl(args, null_value)
117119
if (any(null_args))
118120
args <- args[!null_args]
119121
if (length(args) > 0)
120122
object$args[names(args)] <- args
123+
if (length(eng_args) > 0)
124+
object$eng_args[names(eng_args)] <- eng_args
121125
}
122126

123127
new_model_spec(

tests/testthat/test_nearest_neighbor.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,18 @@ test_that('engine arguments', {
8181

8282
test_that('updating', {
8383

84-
expr1 <- nearest_neighbor() %>% set_engine("kknn", scale = FALSE)
84+
expr1 <- nearest_neighbor() %>% set_engine("kknn", scale = varying())
8585
expr1_exp <- nearest_neighbor(neighbors = 5) %>% set_engine("kknn", scale = FALSE)
8686

8787
expr2 <- nearest_neighbor(neighbors = varying()) %>% set_engine("kknn")
8888
expr2_exp <- nearest_neighbor(neighbors = varying(), weight_func = "triangular") %>% set_engine("kknn")
8989

90-
expr3 <- nearest_neighbor(neighbors = 2, weight_func = varying()) %>% set_engine("kknn")
91-
expr3_exp <- nearest_neighbor(neighbors = 3) %>% set_engine("kknn")
90+
expr3 <- nearest_neighbor(neighbors = 2, weight_func = varying()) %>% set_engine("kknn", scale = varying())
91+
expr3_exp <- nearest_neighbor(neighbors = 3) %>% set_engine("kknn", scale = FALSE)
9292

93-
expect_equal(update(expr1, neighbors = 5), expr1_exp)
93+
expect_equal(update(expr1, neighbors = 5, scale = FALSE), expr1_exp)
9494
expect_equal(update(expr2, weight_func = "triangular"), expr2_exp)
95-
expect_equal(update(expr3, neighbors = 3, fresh = TRUE), expr3_exp)
95+
expect_equal(update(expr3, neighbors = 3, fresh = TRUE, scale = FALSE), expr3_exp)
9696

9797
param_tibb <- tibble::tibble(neighbors = 7, dist_power = 1)
9898
param_list <- as.list(param_tibb)

0 commit comments

Comments
 (0)