Skip to content

Commit 061bef7

Browse files
gardbergxuyxu
andauthored
feat: add hard majority voting strategy for parallel ensembles (#126)
* initial attempt to implementing hard majority voting * fixed linting * removed comments etc * moved hard voting to utils, added test * fixed linting * fixed linting 2 * add changelog * switched to value error * flake8 formating Co-authored-by: xuyxu <xuyx@lamda.nju.edu.cn>
1 parent 3a96d34 commit 061bef7

File tree

6 files changed

+89
-8
lines changed

6 files changed

+89
-8
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Changelog
1818
Ver 0.1.*
1919
---------
2020

21+
* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifier`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
2122
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
2223
* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__
2324
* |Fix| Relax check on input dataloader | `@xuyxu <https://github.com/xuyxu>`__

torchensemble/_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def predict(self, *x):
203203
class BaseTreeEnsemble(BaseModule):
204204
def __init__(
205205
self,
206-
n_estimators,
206+
n_estimators=10,
207207
depth=5,
208208
lamda=1e-3,
209209
cuda=False,
@@ -280,8 +280,11 @@ def evaluate(self, test_loader, return_loss=False):
280280

281281
for _, elem in enumerate(test_loader):
282282
data, target = split_data_target(elem, self.device)
283+
283284
output = self.forward(*data)
285+
284286
_, predicted = torch.max(output.data, 1)
287+
285288
correct += (predicted == target).sum().item()
286289
total += target.size(0)
287290
loss += self._criterion(output, target)

torchensemble/snapshot_ensemble.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,28 @@ def set_scheduler(self, scheduler_name, **kwargs):
209209
"""Implementation on the SnapshotEnsembleClassifier.""", "seq_model"
210210
)
211211
class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier):
212+
def __init__(self, voting_strategy="soft", **kwargs):
    # Build the snapshot ensemble itself; every remaining keyword argument
    # is forwarded untouched to the parent constructors.
    super().__init__(**kwargs)

    # Reject unsupported strategies at construction time. `forward` only
    # handles "soft" and "hard", so any other value would otherwise surface
    # much later as an unrelated error in the middle of a forward pass.
    implemented_strategies = {"soft", "hard"}
    if voting_strategy not in implemented_strategies:
        msg = (
            "Voting strategy {} is not implemented, "
            "please choose from {}."
        )
        raise ValueError(
            msg.format(voting_strategy, implemented_strategies)
        )

    # Strategy used by `forward` to combine the outputs of the snapshots.
    self.voting_strategy = voting_strategy
216+
212217
@torchensemble_model_doc(
213218
"""Implementation on the data forwarding in SnapshotEnsembleClassifier.""", # noqa: E501
214219
"classifier_forward",
215220
)
216221
def forward(self, *x):
    # Per-snapshot class distributions: softmax over dim 1 of the raw
    # output of every estimator stored in `self.estimators_`.
    outputs = [
        F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
    ]

    if self.voting_strategy == "soft":
        # Soft voting: combine the class distributions by averaging.
        proba = op.average(outputs)

    elif self.voting_strategy == "hard":
        # Hard voting: combine the per-estimator votes by majority.
        proba = op.majority_vote(outputs)

    else:
        # Without this branch an unsupported strategy would fall through
        # and fail on `return proba` with an UnboundLocalError.
        msg = (
            "Voting strategy {} is not implemented, "
            "please choose from {}."
        )
        raise ValueError(msg.format(self.voting_strategy, {"soft", "hard"}))

    return proba
220234

221235
@torchensemble_model_doc(
222236
"""Set the attributes on optimizer for SnapshotEnsembleClassifier.""",

torchensemble/tests/test_operator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,14 @@ def test_residual_regression_invalid_shape():
4343
label.view(-1, 1), # 4 * 1
4444
)
4545
assert "should be the same as output" in str(excinfo.value)
46+
47+
48+
def test_majority_voting():
    """Check `op.majority_vote` on three 2-sample, 2-class model outputs."""
    # One entry per model; each entry holds one row of class scores per
    # sample. Two of the three models pick class 0 for the first sample
    # and class 1 for the second.
    per_model_scores = (
        ([0.9, 0.1], [0.2, 0.8]),
        ([0.7, 0.3], [0.1, 0.9]),
        ([0.1, 0.9], [0.8, 0.2]),
    )
    outputs = [
        torch.FloatTensor(np.array(scores)) for scores in per_model_scores
    ]

    expected = np.array(([1, 0], [0, 1]))
    voted = op.majority_vote(outputs)

    assert_array_almost_equal(voted.numpy(), expected)

torchensemble/utils/operator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
import torch.nn.functional as F
6+
from typing import List
67

78

89
__all__ = [
@@ -11,6 +12,7 @@
1112
"onehot_encoding",
1213
"pseudo_residual_classification",
1314
"pseudo_residual_regression",
15+
"majority_vote",
1416
]
1517

1618

@@ -51,3 +53,23 @@ def pseudo_residual_regression(target, output):
5153
raise ValueError(msg.format(target.size(), output.size()))
5254

5355
return target - output
56+
57+
58+
def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor:
    """Compute the majority vote for a list of model outputs.

    Every model votes for its arg-max class on each sample; the most
    frequent vote per sample is returned as a one-hot row.

    Parameters
    ----------
    outputs : list of torch.Tensor
        List of length (n_models) containing tensors with shape
        (n_samples, n_classes).

    Returns
    -------
    majority_one_hots : torch.Tensor
        Tensor with shape (n_samples, n_classes), created with
        ``torch.zeros_like(outputs[0])`` and set to 1 at the winning
        class of every sample.

    Raises
    ------
    ValueError
        If ``outputs`` is empty, or if the first tensor is not 2-D.
    """

    # An empty list used to fail with a bare IndexError on `outputs[0]`;
    # report it the same way as the other invalid-input case.
    if len(outputs) == 0:
        raise ValueError(
            "The list of outputs should contain at least one tensor."
        )

    # Only the first tensor is inspected here: `torch.stack` below
    # requires all tensors to be of the same size.
    if len(outputs[0].shape) != 2:
        msg = (
            "The shape of outputs should be a list tensors of\n"
            "length (n_models) with sizes (n_samples, n_classes).\n"
            "The first tensor had shape {} "
        )
        raise ValueError(msg.format(outputs[0].shape))

    # (n_models, n_samples, n_classes) -> per-model class index for each
    # sample -> most frequent index across the model dimension.
    votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0]

    # One-hot encode the winning class; the zero tensor is freshly
    # allocated, so the in-place scatter does not touch the inputs.
    proba = torch.zeros_like(outputs[0])
    majority_one_hots = proba.scatter_(1, votes.view(-1, 1), 1)

    return majority_one_hots

torchensemble/voting.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,36 @@ def _parallel_fit_per_epoch(
9292
"""Implementation on the VotingClassifier.""", "model"
9393
)
9494
class VotingClassifier(BaseClassifier):
95+
def __init__(self, voting_strategy="soft", **kwargs):
    # All remaining keyword arguments configure the base classifier.
    super(VotingClassifier, self).__init__(**kwargs)

    # Only these strategies are handled by `forward`; anything else is
    # rejected immediately with a ValueError.
    supported = {"soft", "hard"}
    is_supported = voting_strategy in supported
    if not is_supported:
        template = (
            "Voting strategy {} is not implemented, "
            "please choose from {}."
        )
        raise ValueError(template.format(voting_strategy, supported))

    # Strategy used to combine the outputs of the base estimators.
    self.voting_strategy = voting_strategy
109+
95110
@torchensemble_model_doc(
96111
"""Implementation on the data forwarding in VotingClassifier.""",
97112
"classifier_forward",
98113
)
99114
def forward(self, *x):
    # Class distributions from all base estimators: softmax over dim 1
    # of the raw output of every estimator in `self.estimators_`.
    outputs = [
        F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
    ]

    if self.voting_strategy == "soft":
        # Soft voting: average over the class distributions.
        proba = op.average(outputs)

    elif self.voting_strategy == "hard":
        # Hard voting: majority vote over the per-estimator predictions.
        proba = op.majority_vote(outputs)

    else:
        # `voting_strategy` is a plain attribute, so it can hold an
        # unsupported value (set after construction, or by a subclass
        # that does not validate it). Fail with a clear message instead
        # of an UnboundLocalError on `return proba`.
        msg = (
            "Voting strategy {} is not implemented, "
            "please choose from {}."
        )
        raise ValueError(msg.format(self.voting_strategy, {"soft", "hard"}))

    return proba
107127

@@ -167,12 +187,17 @@ def fit(
167187
# Utils
168188
best_acc = 0.0
169189

170-
# Internal helper function on pesudo forward
190+
# Internal helper function on pseudo forward
171191
def _forward(estimators, *x):
172192
outputs = [
173193
F.softmax(estimator(*x), dim=1) for estimator in estimators
174194
]
175-
proba = op.average(outputs)
195+
196+
if self.voting_strategy == "soft":
197+
proba = op.average(outputs)
198+
199+
elif self.voting_strategy == "hard":
200+
proba = op.majority_vote(outputs)
176201

177202
return proba
178203

@@ -287,6 +312,11 @@ def predict(self, *x):
287312
"""Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model"
288313
)
289314
class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier):
315+
def __init__(self, voting_strategy="soft", **kwargs):
    # Every remaining keyword argument is forwarded untouched to the
    # parent constructors.
    super().__init__(**kwargs)

    # This constructor does not run any check of its own on the strategy,
    # so validate it here the same way `VotingClassifier.__init__` does.
    # Otherwise a bad value is only detected once a forward pass runs.
    implemented_strategies = {"soft", "hard"}
    if voting_strategy not in implemented_strategies:
        msg = (
            "Voting strategy {} is not implemented, "
            "please choose from {}."
        )
        raise ValueError(
            msg.format(voting_strategy, implemented_strategies)
        )

    # Strategy used to combine the outputs of the base estimators.
    self.voting_strategy = voting_strategy
319+
290320
@torchensemble_model_doc(
291321
"""Implementation on the data forwarding in NeuralForestClassifier.""",
292322
"classifier_forward",
@@ -420,7 +450,7 @@ def fit(
420450
# Utils
421451
best_loss = float("inf")
422452

423-
# Internal helper function on pesudo forward
453+
# Internal helper function on pseudo forward
424454
def _forward(estimators, *x):
425455
outputs = [estimator(*x) for estimator in estimators]
426456
pred = op.average(outputs)

0 commit comments

Comments
 (0)