44from collections .abc import Iterable
55from contextlib import suppress
66from functools import partial
7+ import itertools
78from operator import itemgetter
89
910import numpy as np
@@ -53,7 +54,8 @@ class BalancingLearner(BaseLearner):
5354 strategy : 'loss_improvements' (default), 'loss', or 'npoints'
5455 The points that the `BalancingLearner` choses can be either based on:
5556 the best 'loss_improvements', the smallest total 'loss' of the
56- child learners, or the number of points per learner, using 'npoints'.
57+ child learners, the number of points per learner, using 'npoints',
58+ or by cycling through the learners one by one using 'cycle'.
5759 One can dynamically change the strategy while the simulation is
5860 running by changing the ``learner.strategy`` attribute.
5961
@@ -90,10 +92,11 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
9092
9193 @property
9294 def strategy (self ):
93- """Can be either 'loss_improvements' (default), 'loss', or 'npoints'
94- The points that the `BalancingLearner` choses can be either based on:
95- the best 'loss_improvements', the smallest total 'loss' of the
96- child learners, or the number of points per learner, using 'npoints'.
95+ """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
96+ 'cycle'. The points that the `BalancingLearner` choses can be either
97+ based on: the best 'loss_improvements', the smallest total 'loss' of
98+ the child learners, the number of points per learner, using 'npoints',
99+ or by going through all learners one by one using 'cycle'.
97100 One can dynamically change the strategy while the simulation is
98101 running by changing the ``learner.strategy`` attribute."""
99102 return self ._strategy
@@ -107,10 +110,12 @@ def strategy(self, strategy):
107110 self ._ask_and_tell = self ._ask_and_tell_based_on_loss
108111 elif strategy == "npoints" :
109112 self ._ask_and_tell = self ._ask_and_tell_based_on_npoints
113+ elif strategy == "cycle" :
114+ self ._ask_and_tell = self ._ask_and_tell_based_on_cycle
110115 else :
111116 raise ValueError (
112- 'Only strategy="loss_improvements", strategy="loss", or '
113- ' strategy="npoints" is implemented.'
117+ 'Only strategy="loss_improvements", strategy="loss",'
118+ ' strategy="npoints", or strategy="cycle" is implemented.'
114119 )
115120
116121 def _ask_and_tell_based_on_loss_improvements (self , n ):
@@ -173,6 +178,20 @@ def _ask_and_tell_based_on_npoints(self, n):
173178 points , loss_improvements = map (list , zip (* selected ))
174179 return points , loss_improvements
175180
181+ def _ask_and_tell_based_on_cycle (self , n ):
182+ if not hasattr (self , "_cycle" ):
183+ self ._cycle = itertools .cycle (range (len (self .learners )))
184+
185+ points , loss_improvements = [], []
186+ for _ in range (n ):
187+ index = next (self ._cycle )
188+ point , loss_improvement = self .learners [index ].ask (n = 1 )
189+ points .append ((index , point [0 ]))
190+ loss_improvements .append (loss_improvement [0 ])
191+ self .tell_pending ((index , point [0 ]))
192+
193+ return points , loss_improvements
194+
176195 def ask (self , n , tell_pending = True ):
177196 """Chose points for learners."""
178197 if n == 0 :
0 commit comments