Skip to content

Commit f120919

Browse files
committed
added unit tests
1 parent cc39a45 commit f120919

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

tests/testthat/test_nearest_neighbor_kknn.R

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,78 @@ test_that('kknn prediction', {
113113

114114
expect_equal(form_pred, predict(res_form, iris[1:5, c("Sepal.Width", "Species")])$.pred)
115115
})
116+
117+
118+
test_that('kknn multi-predict', {
119+
120+
skip_if_not_installed("kknn")
121+
library(kknn)
122+
123+
iris_te <- c(1:2, 50:51, 100:101)
124+
k_vals <- 1:10
125+
126+
res_xy <- fit_xy(
127+
nearest_neighbor(mode = "classification", neighbors = 3) %>%
128+
set_engine("kknn"),
129+
control = ctrl,
130+
x = iris[-iris_te, num_pred],
131+
y = iris$Species[-iris_te]
132+
)
133+
134+
pred_multi <- multi_predict(res_xy, iris[iris_te, num_pred], neighbors = k_vals)
135+
expect_equal(pred_multi %>% unnest() %>% nrow(), length(iris_te) * length(k_vals))
136+
expect_equal(pred_multi %>% nrow(), length(iris_te))
137+
138+
pred_uni <- predict(res_xy, iris[iris_te, num_pred])
139+
pred_uni_obs <-
140+
pred_multi %>%
141+
mutate(.rows = row_number()) %>%
142+
unnest() %>%
143+
dplyr::filter(neighbors == 3) %>%
144+
arrange(.rows) %>%
145+
dplyr::select(.pred_class)
146+
expect_equal(pred_uni, pred_uni_obs)
147+
148+
149+
prob_multi <- multi_predict(res_xy, iris[iris_te, num_pred],
150+
neighbors = k_vals, type = "prob")
151+
expect_equal(prob_multi %>% unnest() %>% nrow(), length(iris_te) * length(k_vals))
152+
expect_equal(prob_multi %>% nrow(), length(iris_te))
153+
154+
prob_uni <- predict(res_xy, iris[iris_te, num_pred], type = "prob")
155+
prob_uni_obs <-
156+
prob_multi %>%
157+
mutate(.rows = row_number()) %>%
158+
unnest() %>%
159+
dplyr::filter(neighbors == 3) %>%
160+
arrange(.rows) %>%
161+
dplyr::select(!!names(prob_uni))
162+
expect_equal(prob_uni, prob_uni_obs)
163+
164+
# ----------------------------------------------------------------------------
165+
# regression
166+
167+
cars_te <- 1:5
168+
k_vals <- 1:10
169+
170+
res_xy <- fit(
171+
nearest_neighbor(mode = "regression", neighbors = 3) %>%
172+
set_engine("kknn"),
173+
control = ctrl,
174+
mpg ~ ., data = mtcars[-cars_te, ]
175+
)
176+
177+
pred_multi <- multi_predict(res_xy, mtcars[cars_te, -1], neighbors = k_vals)
178+
expect_equal(pred_multi %>% unnest() %>% nrow(), length(cars_te) * length(k_vals))
179+
expect_equal(pred_multi %>% nrow(), length(cars_te))
180+
181+
pred_uni <- predict(res_xy, mtcars[cars_te, -1])
182+
pred_uni_obs <-
183+
pred_multi %>%
184+
mutate(.rows = row_number()) %>%
185+
unnest() %>%
186+
dplyr::filter(neighbors == 3) %>%
187+
arrange(.rows) %>%
188+
dplyr::select(.pred)
189+
expect_equal(pred_uni, pred_uni_obs)
190+
})

0 commit comments

Comments
 (0)