@@ -51,7 +51,7 @@ class BalancingLearner(BaseLearner):
5151 function : callable
5252 A function that calls the functions of the underlying learners.
5353 Its signature is ``function(learner_index, point)``.
54- strategy : 'loss_improvements' (default), 'loss', or 'npoints'
54+ strategy : 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'.
5555 The points that the `BalancingLearner` choses can be either based on:
5656 the best 'loss_improvements', the smallest total 'loss' of the
5757 child learners, the number of points per learner, using 'npoints',
@@ -112,6 +112,7 @@ def strategy(self, strategy):
112112 self ._ask_and_tell = self ._ask_and_tell_based_on_npoints
113113 elif strategy == "cycle" :
114114 self ._ask_and_tell = self ._ask_and_tell_based_on_cycle
115+ self ._cycle = itertools .cycle (range (len (self .learners )))
115116 else :
116117 raise ValueError (
117118 'Only strategy="loss_improvements", strategy="loss",'
@@ -179,9 +180,6 @@ def _ask_and_tell_based_on_npoints(self, n):
179180 return points , loss_improvements
180181
181182 def _ask_and_tell_based_on_cycle (self , n ):
182- if not hasattr (self , "_cycle" ):
183- self ._cycle = itertools .cycle (range (len (self .learners )))
184-
185183 points , loss_improvements = [], []
186184 for _ in range (n ):
187185 index = next (self ._cycle )
0 commit comments