@@ -113,3 +113,78 @@ test_that('kknn prediction', {
113113
114114 expect_equal(form_pred , predict(res_form , iris [1 : 5 , c(" Sepal.Width" , " Species" )])$ .pred )
115115})
116+
117+
118+ test_that(' kknn multi-predict' , {
119+
120+ skip_if_not_installed(" kknn" )
121+ library(kknn )
122+
123+ iris_te <- c(1 : 2 , 50 : 51 , 100 : 101 )
124+ k_vals <- 1 : 10
125+
126+ res_xy <- fit_xy(
127+ nearest_neighbor(mode = " classification" , neighbors = 3 ) %> %
128+ set_engine(" kknn" ),
129+ control = ctrl ,
130+ x = iris [- iris_te , num_pred ],
131+ y = iris $ Species [- iris_te ]
132+ )
133+
134+ pred_multi <- multi_predict(res_xy , iris [iris_te , num_pred ], neighbors = k_vals )
135+ expect_equal(pred_multi %> % unnest() %> % nrow(), length(iris_te ) * length(k_vals ))
136+ expect_equal(pred_multi %> % nrow(), length(iris_te ))
137+
138+ pred_uni <- predict(res_xy , iris [iris_te , num_pred ])
139+ pred_uni_obs <-
140+ pred_multi %> %
141+ mutate(.rows = row_number()) %> %
142+ unnest() %> %
143+ dplyr :: filter(neighbors == 3 ) %> %
144+ arrange(.rows ) %> %
145+ dplyr :: select(.pred_class )
146+ expect_equal(pred_uni , pred_uni_obs )
147+
148+
149+ prob_multi <- multi_predict(res_xy , iris [iris_te , num_pred ],
150+ neighbors = k_vals , type = " prob" )
151+ expect_equal(prob_multi %> % unnest() %> % nrow(), length(iris_te ) * length(k_vals ))
152+ expect_equal(prob_multi %> % nrow(), length(iris_te ))
153+
154+ prob_uni <- predict(res_xy , iris [iris_te , num_pred ], type = " prob" )
155+ prob_uni_obs <-
156+ prob_multi %> %
157+ mutate(.rows = row_number()) %> %
158+ unnest() %> %
159+ dplyr :: filter(neighbors == 3 ) %> %
160+ arrange(.rows ) %> %
161+ dplyr :: select(!! names(prob_uni ))
162+ expect_equal(prob_uni , prob_uni_obs )
163+
164+ # ----------------------------------------------------------------------------
165+ # regression
166+
167+ cars_te <- 1 : 5
168+ k_vals <- 1 : 10
169+
170+ res_xy <- fit(
171+ nearest_neighbor(mode = " regression" , neighbors = 3 ) %> %
172+ set_engine(" kknn" ),
173+ control = ctrl ,
174+ mpg ~ . , data = mtcars [- cars_te , ]
175+ )
176+
177+ pred_multi <- multi_predict(res_xy , mtcars [cars_te , - 1 ], neighbors = k_vals )
178+ expect_equal(pred_multi %> % unnest() %> % nrow(), length(cars_te ) * length(k_vals ))
179+ expect_equal(pred_multi %> % nrow(), length(cars_te ))
180+
181+ pred_uni <- predict(res_xy , mtcars [cars_te , - 1 ])
182+ pred_uni_obs <-
183+ pred_multi %> %
184+ mutate(.rows = row_number()) %> %
185+ unnest() %> %
186+ dplyr :: filter(neighbors == 3 ) %> %
187+ arrange(.rows ) %> %
188+ dplyr :: select(.pred )
189+ expect_equal(pred_uni , pred_uni_obs )
190+ })
0 commit comments