Skip to content

Commit abb2fea

Browse files
committed
Use SAMME algorithm for adaboost
1 parent 7e090bb commit abb2fea

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/classification/main.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,17 +300,23 @@ function build_adaboost_stumps(
300300
n_iterations :: Integer;
301301
rng = Random.GLOBAL_RNG) where {S, T}
302302
N = length(labels)
303+
n_labels = length(unique(labels))
304+
base_coeff = log(n_labels - 1)
305+
thresh = 1 - 1 / n_labels
303306
weights = ones(N) / N
304307
stumps = Node{S, T}[]
305308
coeffs = Float64[]
306309
for i in 1:n_iterations
307310
new_stump = build_stump(labels, features, weights; rng=rng)
308311
predictions = apply_tree(new_stump, features)
309312
err = _weighted_error(labels, predictions, weights)
310-
new_coeff = 0.5 * log((1.0 + err) / (1.0 - err))
311-
matches = labels .== predictions
312-
weights[(!).(matches)] *= exp(new_coeff)
313-
weights[matches] *= exp(-new_coeff)
313+
if err >= thresh # should be better than random guess
314+
continue
315+
end
316+
# SAMME algorithm
317+
new_coeff = log((1.0 - err) / err) + base_coeff
318+
unmatches = labels .!= predictions
319+
weights[unmatches] *= exp(new_coeff)
314320
weights /= sum(weights)
315321
push!(coeffs, new_coeff)
316322
push!(stumps, new_stump)

0 commit comments

Comments
 (0)