Skip to content

Commit 969c637

Browse files
authored
Test multiple seeds (#174)
* Test multiple rngs
* Add comment
* Fix rng not being passed correctly
* Revert some changes
* Update comment
* Extend tests
* Simplify test
* Lower accuracy bound
* Fix a bug in the usage of Mersenne Twister
* Fix tests
* Use some more StableRNG
* Use some more StableRNG
* Use some more StableRNG
* Use some more StableRNG
* Use some more StableRNG
* Use some more StableRNG
* Fix Julia 1.6
* Use `StableRNG` in test/classification/adult
* Use `StableRNG` for data generation too
* Put old numbers back
* Add one more rng
1 parent a3398bf commit 969c637

File tree

17 files changed

+166
-124
lines changed

17 files changed

+166
-124
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
version:
19+
version:
2020
- '1.0'
2121
- '1.6'
2222
- '1' # automatically expands to the latest stable 1.x release of Julia

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@ version = "0.10.12"
77
[deps]
88
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
99
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
10-
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1312
ScikitLearnBase = "6e75b9c4-186b-50bd-896f-2d2496a4843e"
1413
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
15-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1614

1715
[compat]
1816
AbstractTrees = "0.3"
1917
ScikitLearnBase = "0.5"
2018
julia = "1"
19+
20+
[extras]
21+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
22+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23+
24+
[targets]
25+
test = ["StableRNGs", "Test"]

src/classification/main.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ end
3737
function _convert(
3838
node :: treeclassifier.NodeMeta{S},
3939
list :: AbstractVector{T},
40-
labels :: AbstractVector{T}) where {S, T}
40+
labels :: AbstractVector{T}
41+
) where {S, T}
4142

4243
if node.is_leaf
4344
return Leaf{T}(list[node.label], labels[node.region])
@@ -138,7 +139,7 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
138139
end
139140

140141

141-
apply_tree(leaf::Leaf{T}, feature::AbstractVector{S}) where {S, T} = leaf.majority
142+
apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority
142143

143144
function apply_tree(tree::Node{S, T}, features::AbstractVector{S}) where {S, T}
144145
if tree.featid == 0
@@ -197,7 +198,7 @@ function build_forest(
197198
min_samples_leaf = 1,
198199
min_samples_split = 2,
199200
min_purity_increase = 0.0;
200-
rng = Random.GLOBAL_RNG) where {S, T}
201+
rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG) where {S, T}
201202

202203
if n_trees < 1
203204
throw("the number of trees must be >= 1")
@@ -221,7 +222,12 @@ function build_forest(
221222

222223
if rng isa Random.AbstractRNG
223224
Threads.@threads for i in 1:n_trees
224-
inds = rand(rng, 1:t_samples, n_samples)
225+
# The Mersenne Twister (Julia's default) is not thread-safe.
226+
_rng = copy(rng)
227+
# Draw `i` values from the copied RNG so that each tree starts from a different state.
228+
# This is the only way given that only a `copy` can be expected to exist for RNGs.
229+
rand(_rng, i)
230+
inds = rand(_rng, 1:t_samples, n_samples)
225231
forest[i] = build_tree(
226232
labels[inds],
227233
features[inds,:],
@@ -231,9 +237,9 @@ function build_forest(
231237
min_samples_split,
232238
min_purity_increase,
233239
loss = loss,
234-
rng = rng)
240+
rng = _rng)
235241
end
236-
elseif rng isa Integer # each thread gets its own seeded rng
242+
else # each thread gets its own seeded rng
237243
Threads.@threads for i in 1:n_trees
238244
Random.seed!(rng + i)
239245
inds = rand(1:t_samples, n_samples)
@@ -247,8 +253,6 @@ function build_forest(
247253
min_purity_increase,
248254
loss = loss)
249255
end
250-
else
251-
throw("rng must of be type Integer or Random.AbstractRNG")
252256
end
253257

254258
return Ensemble{S, T}(forest)
@@ -298,7 +302,7 @@ function build_adaboost_stumps(
298302
labels :: AbstractVector{T},
299303
features :: AbstractMatrix{S},
300304
n_iterations :: Integer;
301-
rng = Random.GLOBAL_RNG) where {S, T}
305+
rng = Random.GLOBAL_RNG) where {S, T}
302306
N = length(labels)
303307
weights = ones(N) / N
304308
stumps = Node{S, T}[]

src/measures.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ function _nfoldCV(classifier::Symbol, labels::AbstractVector{T}, features::Abstr
135135
predictions = apply_forest(model, test_features)
136136
elseif classifier == :stumps
137137
model, coeffs = build_adaboost_stumps(
138-
train_labels, train_features, n_iterations)
138+
train_labels, train_features, n_iterations; rng=rng)
139139
predictions = apply_adaboost_stumps(model, coeffs, test_features)
140140
end
141141
cm = confusion_matrix(test_labels, predictions)
@@ -186,6 +186,7 @@ function nfoldCV_stumps(
186186
n_iterations ::Integer = 10;
187187
verbose :: Bool = true,
188188
rng = Random.GLOBAL_RNG) where {S, T}
189+
rng = mk_rng(rng)::Random.AbstractRNG
189190
_nfoldCV(:stumps, labels, features, n_folds, n_iterations; verbose=verbose, rng=rng)
190191
end
191192

src/regression/main.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function build_forest(
5656
min_samples_leaf = 5,
5757
min_samples_split = 2,
5858
min_purity_increase = 0.0;
59-
rng = Random.GLOBAL_RNG) where {S, T <: Float64}
59+
rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG) where {S, T <: Float64}
6060

6161
if n_trees < 1
6262
throw("the number of trees must be >= 1")
@@ -77,7 +77,12 @@ function build_forest(
7777

7878
if rng isa Random.AbstractRNG
7979
Threads.@threads for i in 1:n_trees
80-
inds = rand(rng, 1:t_samples, n_samples)
80+
# The Mersenne Twister (Julia's default) is not thread-safe.
81+
_rng = copy(rng)
82+
# Draw `i` values from the copied RNG so that each tree starts from a different state.
83+
# This is the only way given that only a `copy` can be expected to exist for RNGs.
84+
rand(_rng, i)
85+
inds = rand(_rng, 1:t_samples, n_samples)
8186
forest[i] = build_tree(
8287
labels[inds],
8388
features[inds,:],
@@ -86,9 +91,9 @@ function build_forest(
8691
min_samples_leaf,
8792
min_samples_split,
8893
min_purity_increase,
89-
rng = rng)
94+
rng = _rng)
9095
end
91-
elseif rng isa Integer # each thread gets its own seeded rng
96+
else # each thread gets its own seeded rng
9297
Threads.@threads for i in 1:n_trees
9398
Random.seed!(rng + i)
9499
inds = rand(1:t_samples, n_samples)
@@ -101,8 +106,6 @@ function build_forest(
101106
min_samples_split,
102107
min_purity_increase)
103108
end
104-
else
105-
throw("rng must of be type Integer or Random.AbstractRNG")
106109
end
107110

108111
return Ensemble{S, T}(forest)

test/classification/adult.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
features, labels = load_data("adult")
77

8-
model = build_tree(labels, features)
8+
model = build_tree(labels, features; rng=StableRNG(1))
99
preds = apply_tree(model, features)
1010
cm = confusion_matrix(labels, preds)
1111
@test cm.accuracy > 0.99
@@ -15,35 +15,35 @@ labels = string.(labels)
1515

1616
n_subfeatures = 3
1717
n_trees = 5
18-
model = build_forest(labels, features, n_subfeatures, n_trees)
18+
model = build_forest(labels, features, n_subfeatures, n_trees; rng=StableRNG(1))
1919
preds = apply_forest(model, features)
2020
cm = confusion_matrix(labels, preds)
2121
@test cm.accuracy > 0.9
2222

2323
n_iterations = 15
24-
model, coeffs = build_adaboost_stumps(labels, features, n_iterations);
24+
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
2525
preds = apply_adaboost_stumps(model, coeffs, features);
2626
cm = confusion_matrix(labels, preds);
2727
@test cm.accuracy > 0.8
2828

2929
println("\n##### 3 foldCV Classification Tree #####")
3030
pruning_purity = 0.9
3131
nfolds = 3
32-
accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; verbose=false);
32+
accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; rng=StableRNG(1), verbose=false);
3333
@test mean(accuracy) > 0.8
3434

3535
println("\n##### 3 foldCV Classification Forest #####")
3636
n_subfeatures = 2
3737
n_trees = 10
3838
n_folds = 3
3939
partial_sampling = 0.5
40-
accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; verbose=false)
40+
accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(1), verbose=false)
4141
@test mean(accuracy) > 0.8
4242

4343
println("\n##### nfoldCV Classification Adaboosted Stumps #####")
4444
n_iterations = 15
4545
n_folds = 3
46-
accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; verbose=false);
46+
accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=StableRNG(1), verbose=false);
4747
@test mean(accuracy) > 0.8
4848

4949
end # @testset

test/classification/digits.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ model = DecisionTree.build_forest(
6969
max_depth,
7070
min_samples_leaf,
7171
min_samples_split,
72-
min_purity_increase)
72+
min_purity_increase;
73+
rng=StableRNG(1))
7374
preds = apply_forest(model, X)
7475
cm = confusion_matrix(Y, preds)
7576
@test cm.accuracy > 0.95

test/classification/heterogeneous.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,29 @@
55
m, n = 10^2, 5
66

77
tf = [trues(Int(m/2)) falses(Int(m/2))]
8-
inds = Random.randperm(m)
8+
inds = Random.randperm(StableRNG(1), m)
99
labels = string.(tf[inds])
1010

1111
features = Array{Any}(undef, m, n)
12-
features[:,:] = randn(m, n)
13-
features[:,2] = string.(tf[Random.randperm(m)])
12+
features[:,:] = randn(StableRNG(1), m, n)
13+
features[:,2] = string.(tf[Random.randperm(StableRNG(1), m)])
1414
features[:,3] = map(t -> round.(Int, t), features[:,3])
1515
features[:,4] = tf[inds]
1616

17-
model = build_tree(labels, features)
17+
model = build_tree(labels, features; rng=StableRNG(1))
1818
preds = apply_tree(model, features)
1919
cm = confusion_matrix(labels, preds)
2020
@test cm.accuracy > 0.9
2121

2222
n_subfeatures = 2
2323
n_trees = 3
24-
model = build_forest(labels, features, n_subfeatures, n_trees)
24+
model = build_forest(labels, features, n_subfeatures, n_trees; rng=StableRNG(1))
2525
preds = apply_forest(model, features)
2626
cm = confusion_matrix(labels, preds)
2727
@test cm.accuracy > 0.9
2828

2929
n_subfeatures = 7
30-
model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures)
30+
model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1))
3131
preds = apply_adaboost_stumps(model, coeffs, features)
3232
cm = confusion_matrix(labels, preds)
3333
@test cm.accuracy > 0.9

test/classification/iris.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ cm = confusion_matrix(labels, preds)
5959
# run n-fold cross validation for pruned tree
6060
println("\n##### nfoldCV Classification Tree #####")
6161
nfolds = 3
62-
accuracy = nfoldCV_tree(labels, features, nfolds)
62+
accuracy = nfoldCV_tree(labels, features, nfolds; rng=StableRNG(1))
6363
@test mean(accuracy) > 0.8
6464

6565
# train random forest classifier
6666
n_trees = 10
6767
n_subfeatures = 2
6868
partial_sampling = 0.5
69-
model = build_forest(labels, features, n_subfeatures, n_trees, partial_sampling)
69+
model = build_forest(labels, features, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(2))
7070
preds = apply_forest(model, features)
7171
cm = confusion_matrix(labels, preds)
7272
@test cm.accuracy > 0.95
@@ -80,12 +80,12 @@ n_subfeatures = 2
8080
n_trees = 10
8181
n_folds = 3
8282
partial_sampling = 0.5
83-
accuracy = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees, partial_sampling)
83+
accuracy = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(1))
8484
@test mean(accuracy) > 0.9
8585

8686
# train adaptive-boosted decision stumps
8787
n_iterations = 15
88-
model, coeffs = build_adaboost_stumps(labels, features, n_iterations)
88+
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1))
8989
preds = apply_adaboost_stumps(model, coeffs, features)
9090
cm = confusion_matrix(labels, preds)
9191
@test cm.accuracy > 0.9
@@ -97,7 +97,7 @@ probs = apply_adaboost_stumps_proba(model, coeffs, features, classes)
9797
println("\n##### nfoldCV Classification Adaboosted Stumps #####")
9898
n_iterations = 15
9999
nfolds = 3
100-
accuracy = nfoldCV_stumps(labels, features, nfolds, n_iterations)
100+
accuracy = nfoldCV_stumps(labels, features, nfolds, n_iterations; rng=StableRNG(1))
101101
@test mean(accuracy) > 0.85
102102

103103
end # @testset

test/classification/low_precision.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ Random.seed!(16)
55

66
n,m = 10^3, 5;
77
features = Array{Any}(undef, n, m);
8-
features[:,:] = rand(n, m);
8+
features[:,:] = rand(StableRNG(1), n, m);
99
features[:,1] = round.(Int32, features[:,1]); # convert a column of 32bit integers
10-
weights = rand(-1:1,m);
10+
weights = rand(StableRNG(1), -1:1, m);
1111
labels = round.(Int32, features * weights);
1212

1313
model = build_stump(labels, features)
@@ -25,7 +25,8 @@ model = build_tree(
2525
n_subfeatures, max_depth,
2626
min_samples_leaf,
2727
min_samples_split,
28-
min_purity_increase)
28+
min_purity_increase;
29+
rng=StableRNG(1))
2930
preds = apply_tree(model, features)
3031
cm = confusion_matrix(labels, preds)
3132
@test typeof(preds) == Vector{Int32}
@@ -40,14 +41,15 @@ model = build_forest(
4041
n_subfeatures,
4142
n_trees,
4243
partial_sampling,
43-
max_depth)
44+
max_depth;
45+
rng=StableRNG(1))
4446
preds = apply_forest(model, features)
4547
cm = confusion_matrix(labels, preds)
4648
@test typeof(preds) == Vector{Int32}
4749
@test cm.accuracy > 0.9
4850

4951
n_iterations = Int32(25)
50-
model, coeffs = build_adaboost_stumps(labels, features, n_iterations);
52+
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
5153
preds = apply_adaboost_stumps(model, coeffs, features);
5254
cm = confusion_matrix(labels, preds)
5355
@test typeof(preds) == Vector{Int32}
@@ -67,7 +69,8 @@ accuracy = nfoldCV_tree(
6769
max_depth,
6870
min_samples_leaf,
6971
min_samples_split,
70-
min_purity_increase)
72+
min_purity_increase;
73+
rng=StableRNG(1))
7174
@test mean(accuracy) > 0.7
7275

7376
println("\n##### nfoldCV Classification Forest #####")
@@ -87,12 +90,13 @@ accuracy = nfoldCV_forest(
8790
max_depth,
8891
min_samples_leaf,
8992
min_samples_split,
90-
min_purity_increase)
93+
min_purity_increase;
94+
rng=StableRNG(1))
9195
@test mean(accuracy) > 0.7
9296

9397
println("\n##### nfoldCV Adaboosted Stumps #####")
9498
n_iterations = Int32(25)
95-
accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations)
99+
accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=StableRNG(1))
96100
@test mean(accuracy) > 0.6
97101

98102

0 commit comments

Comments
 (0)