@@ -92,16 +92,36 @@ def _parallel_fit_per_epoch(
9292 """Implementation on the VotingClassifier.""" , "model"
9393)
9494class VotingClassifier (BaseClassifier ):
95+ def __init__ (self , voting_strategy = "soft" , ** kwargs ):
96+ super (VotingClassifier , self ).__init__ (** kwargs )
97+
98+ implemented_strategies = {"soft" , "hard" }
99+ if voting_strategy not in implemented_strategies :
100+ msg = (
101+ "Voting strategy {} is not implemented, "
102+ "please choose from {}."
103+ )
104+ raise ValueError (
105+ msg .format (voting_strategy , implemented_strategies )
106+ )
107+
108+ self .voting_strategy = voting_strategy
109+
95110 @torchensemble_model_doc (
96111 """Implementation on the data forwarding in VotingClassifier.""" ,
97112 "classifier_forward" ,
98113 )
99114 def forward (self , * x ):
100- # Average over class distributions from all base estimators.
115+
101116 outputs = [
102117 F .softmax (estimator (* x ), dim = 1 ) for estimator in self .estimators_
103118 ]
104- proba = op .average (outputs )
119+
120+ if self .voting_strategy == "soft" :
121+ proba = op .average (outputs )
122+
123+ elif self .voting_strategy == "hard" :
124+ proba = op .majority_vote (outputs )
105125
106126 return proba
107127
@@ -167,12 +187,17 @@ def fit(
167187 # Utils
168188 best_acc = 0.0
169189
170- # Internal helper function on pesudo forward
190+ # Internal helper function on pseudo forward
171191 def _forward (estimators , * x ):
172192 outputs = [
173193 F .softmax (estimator (* x ), dim = 1 ) for estimator in estimators
174194 ]
175- proba = op .average (outputs )
195+
196+ if self .voting_strategy == "soft" :
197+ proba = op .average (outputs )
198+
199+ elif self .voting_strategy == "hard" :
200+ proba = op .majority_vote (outputs )
176201
177202 return proba
178203
@@ -287,6 +312,11 @@ def predict(self, *x):
287312 """Implementation on the NeuralForestClassifier.""" , "tree_ensmeble_model"
288313)
289314class NeuralForestClassifier (BaseTreeEnsemble , VotingClassifier ):
315+ def __init__ (self , voting_strategy = "soft" , ** kwargs ):
316+ super ().__init__ (** kwargs )
317+
318+ self .voting_strategy = voting_strategy
319+
290320 @torchensemble_model_doc (
291321 """Implementation on the data forwarding in NeuralForestClassifier.""" ,
292322 "classifier_forward" ,
@@ -420,7 +450,7 @@ def fit(
420450 # Utils
421451 best_loss = float ("inf" )
422452
423- # Internal helper function on pesudo forward
453+ # Internal helper function on pseudo forward
424454 def _forward (estimators , * x ):
425455 outputs = [estimator (* x ) for estimator in estimators ]
426456 pred = op .average (outputs )
0 commit comments