
Commit 904146d

author: ercbk
committed: finished raschka portion of performance experiment
1 parent: ad4476e

File tree

11 files changed: +462, -63 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,8 @@
 .Ruserdata
 .env
 .drake
+.drake-raschka
+README.html
 ec2-ssh-raw.log
 README_cache
 check-results.R

README.Rmd

Lines changed: 134 additions & 23 deletions

@@ -9,38 +9,39 @@ output: github_document
 
 Nested cross-validation has become a recommended technique for situations in which the size of our dataset is insufficient to simultaneously handle hyperparameter tuning and algorithm comparison. Examples of such situations include: proof of concept, start-ups, medical studies, time series, etc. Using standard methods such as k-fold cross-validation in these cases may result in substantial increases in optimization bias. Nested cross-validation has been shown to produce less biased, out-of-sample error estimates even using datasets with only hundreds of rows and therefore gives a better judgement of generalization performance.
 
-The primary issue with this technique is that it is computationally very expensive with potentially tens of 1000s of models being trained during the process. While researching this technique, I found two slightly different methods of performing nested cross-validation — one authored by [Sabastian Raschka](https://github.com/rasbt/stat479-machine-learning-fs19/blob/master/11_eval4-algo/code/11-eval4-algo__nested-cv_verbose1.ipynb) and the other by [Max Kuhn and Kjell Johnson](https://tidymodels.github.io/rsample/articles/Applications/Nested_Resampling.html).
+The primary issue with this technique is that it can be computationally expensive, with potentially tens of thousands of models being trained during the process. While researching this technique, I found two slightly different variations of nested cross-validation — one authored by [Sebastian Raschka](https://github.com/rasbt/stat479-machine-learning-fs19/blob/master/11_eval4-algo/code/11-eval4-algo__nested-cv_verbose1.ipynb) and the other by [Max Kuhn and Kjell Johnson](https://tidymodels.github.io/rsample/articles/Applications/Nested_Resampling.html).
+
+Various elements of the technique affect the run times and can be altered to improve performance. These include:
+
+1. Hyperparameter value grids
+2. Grid search strategy
+3. Inner-Loop CV strategy
+4. Outer-Loop CV strategy
+
 I'll be examining two aspects of nested cross-validation:
 
 1. Duration: Find out which packages and combinations of model functions give us the fastest implementation of each method.
-2. Performance: First, develop a testing framework. Then, using a generated dataset, calculate how many repeats, given the sample size, should we expect to need in order to obtain a reasonably accurate out-of-sample error estimate.
+2. Performance: First, develop a testing framework. Then, for a given data-generating process, how large a sample size is needed to obtain a reasonably accurate out-of-sample error estimate? And how many repeats in the outer-loop cv strategy should be used to calculate this error estimate?
 
 
-## Duration Experiment
+## Duration
 #### Experiment details:
 
 * Random Forest and Elastic Net Regression algorithms
-* Both with 100x2 hyperparameter grids
+* Both algorithms are tuned with 100x2 hyperparameter grids using a Latin hypercube design.
+* From {mlbench}, I'm using the simulated data set, friedman1, from Friedman's Multivariate Adaptive Regression Splines (MARS) paper.
 * Kuhn-Johnson
-  + 100 observations 10 features, numeric target variable
+  + 100 observations: 10 features, numeric target variable
   + outer loop: 2 repeats, 10 folds
   + inner loop: 25 bootstrap resamples
 * Raschka
   + 5000 observations: 10 features, numeric target variable
   + outer loop: 5 folds
   + inner loop: 2 folds
 
-The sizes of the data sets are the same as those in the original scripts by the authors. [MLFlow](https://mlflow.org/docs/latest/index.html) is used to keep track of the duration (seconds) of each run along with the implementation and method used.
-
+The sizes of the data sets are the same as those in the original scripts by the authors. Using Kuhn-Johnson, 50,000 models (grid size * number of repeats * number of folds in the outer loop * number of folds/resamples in the inner loop) are trained for each algorithm — using Raschka's, 1,001 models for each algorithm. The one extra model in the Raschka variation comes from his method of choosing the hyperparameter values for the final model: he performs an extra k-fold cross-validation, using the inner-loop cv strategy, on the entire training set. Kuhn-Johnson instead uses a majority vote: whichever set of hyperparameter values is chosen most often during the inner-loop tuning procedure is used to fit the final model.
 
-Various elements of the technique can be altered to improve performance. These include:
-
-1. Hyperparameter value grids
-2. Outer-Loop CV strategy
-3. Inner-Loop CV strategy
-4. Grid search strategy
-
-These elements also affect the run times. Both methods are using the same size grids, but Kuhn-Johnson uses repeats and more folds in the outer and inner loops while Raschka's trains an extra model over the entire training set at the end at the end. Using Kuhn-Johnson, 50,000 models (grid size * number of repeats * number of folds in the outer-loop * number of folds/resamples in the inner-loop) are trained for each algorithm — using Raschka's, 1,001 models.
+[MLFlow](https://mlflow.org/docs/latest/index.html) is used to keep track of the duration (seconds) of each run along with the implementation and method used.
 
 ![](duration-experiment/outputs/0225-results.png)
 
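As a concrete sketch of the two resampling schemes and the model counts described above: {rsample}'s `nested_cv()` is the helper the Kuhn-Johnson article is built around, while the grid and the object names below are illustrative only, not the experiment's actual code.

```r
library(rsample)  # nested_cv(), vfold_cv(), bootstraps()
library(dials)    # grid_latin_hypercube() and parameter objects
library(mlbench)  # mlbench.friedman1()

# friedman1 data: 10 features, numeric target
sim <- mlbench.friedman1(n = 100, sd = 1)
dat <- data.frame(sim$x, y = sim$y)

# a 100x2 Latin hypercube grid, e.g. for the Elastic Net's two hyperparameters
enet_grid <- grid_latin_hypercube(penalty(), mixture(), size = 100)

# Kuhn-Johnson: outer loop = 2 repeats of 10 folds, inner loop = 25 bootstraps
kj_rs <- nested_cv(dat,
                   outside = vfold_cv(v = 10, repeats = 2),
                   inside  = bootstraps(times = 25))

# Raschka: outer loop = 5 folds, inner loop = 2 folds
r_rs <- nested_cv(dat,
                  outside = vfold_cv(v = 5),
                  inside  = vfold_cv(v = 2))

# models trained per algorithm:
nrow(enet_grid) * 2 * 10 * 25  # Kuhn-Johnson: 50,000
nrow(enet_grid) * 5 * 2 + 1    # Raschka: 1,001, counting the extra final-model step as one
```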
@@ -104,18 +105,20 @@ durations
 ```
 
 
-## Performance Experiment
+## Performance
 
 #### Experiment details:
 
+* The same data, algorithms, and hyperparameter grids are used.
 * The fastest implementation of each method is used in running a nested cross-validation with different sizes of data ranging from 100 to 5000 observations and different numbers of repeats of the outer-loop cv strategy.
 * The {mlr3} implementation is the fastest for Raschka's method, but the Ranger-Kuhn-Johnson implementation is close. To simplify, I am using [Ranger-Kuhn-Johnson](https://github.com/ercbk/nested-cross-validation-comparison/blob/master/duration-experiment/kuhn-johnson/nested-cv-ranger-kj.R) for both methods.
-* The chosen algorithm and hyperparameters predicts on a 100K row simulated dataset.
+* The chosen algorithm with hyperparameters is fit on the entire training set, and the resulting final model predicts on a 100K row Friedman dataset.
 * The percent error between the average mean absolute error (MAE) across the outer-loop folds and the MAE of the predictions on this 100K dataset is calculated for each combination of repeat, data size, and method.
 * To make this experiment manageable in terms of runtimes, I am using AWS instances: an r5.2xlarge for the Elastic Net and an r5.24xlarge for Random Forest.
+  + Also see the Other Notes section
 * Iterating through different numbers of repeats, sample sizes, and methods makes a functional approach more appropriate than running imperative scripts. Given the long runtimes and the impermanent nature of my internet connection, it would also be nice to cache each iteration as it finishes. The [{drake}](https://github.com/ropensci/drake) package is superb on both counts, so I'm using it to orchestrate.
 
-```{r perf_build_times, echo=FALSE, message=FALSE}
+```{r perf_build_times_kj, echo=FALSE, message=FALSE}
 
 pacman::p_load(extrafont, dplyr, purrr, lubridate, ggplot2, ggfittext, drake, patchwork)
 bt <- build_times(starts_with("ncv_results"), digits = 4)
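The majority-vote selection and the percent-error metric described in the bullets above can be expressed in a few lines. This is a sketch with made-up values; the last line mirrors the `delta_error / oos_error` computation in the chunks that follow.

```r
library(dplyr)

# hypothetical winning hyperparameter sets from each outer-loop iteration
inner_winners <- tibble::tribble(
  ~penalty, ~mixture,
     0.010,     0.25,
     0.010,     0.25,
     0.032,     0.80
)

# Kuhn-Johnson's majority vote: the most frequently chosen set fits the final model
final_params <- inner_winners %>%
  count(penalty, mixture, sort = TRUE) %>%
  slice(1)

# percent error between the averaged outer-fold MAE and the MAE of the final
# model's predictions on the 100K-row Friedman set (values are made up)
avg_fold_mae <- mean(c(2.31, 2.18, 2.45))
oos_mae      <- 2.25
percent_error <- round(abs(avg_fold_mae - oos_mae) / oos_mae, 3)
```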
@@ -143,7 +146,7 @@ readr::write_csv(subtargets, "performance-experiment/output/perf-exp-output.csv"
 
 ```
 
-```{r perf_bt_charts, echo=FALSE, message=FALSE}
+```{r perf_bt_charts_kj, echo=FALSE, message=FALSE}
 
 fill_colors <- unname(swatches::read_ase("palettes/Forest Floor.ase"))
 
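The {drake} orchestration mentioned in the experiment details might look roughly like the plan below. `run_ncv()` is a hypothetical wrapper around the Ranger-Kuhn-Johnson script; the `cross()` transformation branches one target over every sample-size/repeats combination, and drake caches each subtarget as it finishes.

```r
library(drake)

# run_ncv() is a hypothetical wrapper around the nested-cv script
plan <- drake_plan(
  ncv_results = target(
    run_ncv(n = n, repeats = repeats),
    transform = cross(n = c(100, 800, 2000, 5000),
                      repeats = c(1, 2, 3, 4, 5))
  )
)

# builds subtargets like ncv_results_100_1; finished ones are skipped on re-run
make(plan)
```

This is also why the chunks here query `build_times(starts_with("ncv_results"))`: each n/repeats combination becomes its own cached subtarget.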
@@ -172,7 +175,7 @@ b <- ggplot(subtargets, aes(y = elapsed, x = repeats,
 
 ```
 
-```{r perf-error-line, echo=FALSE, message=FALSE}
+```{r perf_error_line_kj, echo=FALSE, message=FALSE}
 e <- ggplot(subtargets, aes(x = repeats, y = percent_error, group = n)) +
   geom_point(aes(color = n), size = 3) +
   geom_line(aes(color = n), size = 2) +
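For the final-model step described in the experiment details (fit the chosen hyperparameters on the entire training set, then predict on 100K Friedman rows), here is a minimal sketch with {ranger} and {mlbench}; the hyperparameter values are placeholders, not the experiment's tuned ones.

```r
library(ranger)
library(mlbench)

# training set and 100K-row holdout, both from the friedman1 generator
train_sim <- mlbench.friedman1(n = 2000, sd = 1)
train     <- data.frame(train_sim$x, y = train_sim$y)
big_sim   <- mlbench.friedman1(n = 1e5, sd = 1)
holdout   <- data.frame(big_sim$x, y = big_sim$y)

# fit the chosen algorithm on the full training set (placeholder hyperparameters)
final_fit <- ranger(y ~ ., data = train, mtry = 3, min.node.size = 5)

# MAE of the final model's predictions on the 100K rows
preds   <- predict(final_fit, data = holdout)$predictions
oos_mae <- mean(abs(preds - holdout$y))
```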
@@ -196,7 +199,7 @@ e <- ggplot(subtargets, aes(x = repeats, y = percent_error, group = n)) +
   )
 ```
 
-```{r kj-patch, echo=FALSE, fig.width=12, fig.height=7}
+```{r kj_patch_kj, echo=FALSE, fig.width=12, fig.height=7}
 b + e + plot_layout(guides = "auto") +
   plot_annotation(title = "Kuhn-Johnson") &
   theme(legend.position = "top",
@@ -212,14 +215,122 @@ b + e + plot_layout(guides = "auto") +
 
 #### Results:
 
-Kuhn-Johnson:
-
 * Runtimes for n = 100 and n = 800 are close, and there's a large jump in runtime going from n = 2000 to n = 5000.
 * The number of repeats has little effect on the percent error.
 * For n = 100, there is substantially more variation in percent error than in the other sample sizes.
 * While there is a large runtime cost that comes with increasing the sample size from 2000 to 5000 observations, it doesn't seem to provide any benefit in gaining a more accurate estimate of the out-of-sample error.
 
 
+```{r perf_build_times_r, echo=FALSE, message=FALSE}
+
+cache_raschka <- drake_cache(path = ".drake-raschka")
+
+bt_r <- build_times(starts_with("ncv_results"),
+                    digits = 4, cache = cache_raschka)
+
+subtarget_bts_r <- bt_r %>%
+  filter(stringr::str_detect(target, pattern = "[0-9]_([0-9]|[a-z])")) %>%
+  select(target, elapsed)
+
+subtargets_raw_r <- map_dfr(subtarget_bts_r$target, function(x) {
+  results <- readd(x, character_only = TRUE,
+                   cache = cache_raschka) %>%
+    mutate(subtarget = x) %>%
+    select(subtarget, everything())
+}) %>%
+  inner_join(subtarget_bts_r, by = c("subtarget" = "target"))
+
+subtargets_r <- subtargets_raw_r %>%
+  mutate(repeats = factor(repeats),
+         n = factor(n),
+         elapsed = round(as.numeric(elapsed)/3600, 2),
+         percent_error = round(delta_error/oos_error, 3))
+
+readr::write_csv(subtargets_r, "performance-experiment/output/perf-exp-output-r.csv")
+# readr::write_rds(subtargets_r, "performance-experiment/output/perf-exp-output-backup-r.rds")
+
+```
+
+```{r perf_bt_charts_r, echo=FALSE, message=FALSE}
+
+b_r <- ggplot(subtargets_r, aes(y = elapsed, x = repeats,
+                                fill = n, label = elapsed)) +
+  geom_col(position = position_dodge(width = 0.85)) +
+  scale_fill_manual(values = fill_colors[4:7]) +
+  geom_bar_text(position = "dodge", min.size = 9,
+                place = "right", contrast = TRUE) +
+  coord_flip() +
+  labs(y = "Runtime (hrs)", x = "Repeats",
+       fill = "Sample Size") +
+  theme(title = element_text(family = "Roboto"),
+        text = element_text(family = "Roboto"),
+        legend.position = "top",
+        legend.background = element_rect(fill = "ivory"),
+        legend.key = element_rect(fill = "ivory"),
+        axis.ticks = element_blank(),
+        panel.background = element_rect(fill = "ivory",
+                                        colour = "ivory"),
+        plot.background = element_rect(fill = "ivory"),
+        panel.border = element_blank(),
+        panel.grid.major = element_blank(),
+        panel.grid.minor = element_blank()
+  )
+
+```
+
+```{r perf_error_line_r, echo=FALSE, message=FALSE}
+e_r <- ggplot(subtargets_r, aes(x = repeats, y = percent_error, group = n)) +
+  geom_point(aes(color = n), size = 3) +
+  geom_line(aes(color = n), size = 2) +
+  expand_limits(y = c(0, 0.10)) +
+  scale_y_continuous(labels = scales::percent_format(accuracy = 0.1),
+                     breaks = seq(0, 0.125, by = 0.025)) +
+  scale_color_manual(values = fill_colors[4:7]) +
+  labs(y = "Percent Error", x = "Repeats",
+       color = "Sample Size") +
+  theme(title = element_text(family = "Roboto"),
+        text = element_text(family = "Roboto"),
+        legend.position = "top",
+        legend.background = element_rect(fill = "ivory"),
+        legend.key = element_rect(fill = "ivory"),
+        axis.ticks = element_blank(),
+        panel.background = element_rect(fill = "ivory",
+                                        color = "ivory"),
+        plot.background = element_rect(fill = "ivory"),
+        panel.border = element_blank(),
+        panel.grid.major = element_blank(),
+        panel.grid.minor = element_blank()
+  )
+```
+
+```{r raschka_patch, echo=FALSE, fig.width=12, fig.height=7}
+b_r + e_r + plot_layout(guides = "auto") +
+  plot_annotation(title = "Raschka") &
+  theme(legend.position = "top",
+        legend.text = element_text(size = 12),
+        axis.text.x = element_text(size = 11,
+                                   face = "bold"),
+        axis.text.y = element_text(size = 11,
+                                   face = "bold"),
+        panel.background = element_rect(fill = "ivory",
+                                        color = "ivory"),
+        plot.background = element_rect(fill = "ivory"))
+```
+
+
+#### Results:
+
+* The longest runtime is under 30 minutes, so runtime isn't a large consideration if we are making a choice about sample size.
+* There isn't much difference in runtime between n = 100 and n = 2000.
+* For n = 100, there's a relatively large change in percent error when going from 1 repeat to 2 repeats. The error estimate then stabilizes for repeats 3 through 5.
+* n = 5000 gives poorer out-of-sample error estimates than n = 800 and n = 2000 for all values of repeats.
+* n = 800 stays under 2.5% error for all repeat values but also shows considerable volatility.
+
 
 
 References