44from collections .abc import Iterable
55from contextlib import suppress
66from functools import partial
7+ import itertools
78from operator import itemgetter
89
910import numpy as np
@@ -50,10 +51,11 @@ class BalancingLearner(BaseLearner):
5051 function : callable
5152 A function that calls the functions of the underlying learners.
5253 Its signature is ``function(learner_index, point)``.
53- strategy : 'loss_improvements' (default), 'loss', or 'npoints'
54+ strategy : 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'.
5455 The points that the `BalancingLearner` choses can be either based on:
5556 the best 'loss_improvements', the smallest total 'loss' of the
56- child learners, or the number of points per learner, using 'npoints'.
57+ child learners, the number of points per learner, using 'npoints',
58+ or by cycling through the learners one by one using 'cycle'.
5759 One can dynamically change the strategy while the simulation is
5860 running by changing the ``learner.strategy`` attribute.
5961
@@ -90,10 +92,11 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
9092
9193 @property
9294 def strategy (self ):
93- """Can be either 'loss_improvements' (default), 'loss', or 'npoints'
94- The points that the `BalancingLearner` choses can be either based on:
95- the best 'loss_improvements', the smallest total 'loss' of the
96- child learners, or the number of points per learner, using 'npoints'.
95+ """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
96+ 'cycle'. The points that the `BalancingLearner` choses can be either
97+ based on: the best 'loss_improvements', the smallest total 'loss' of
98+ the child learners, the number of points per learner, using 'npoints',
99+ or by going through all learners one by one using 'cycle'.
97100 One can dynamically change the strategy while the simulation is
98101 running by changing the ``learner.strategy`` attribute."""
99102 return self ._strategy
@@ -107,10 +110,13 @@ def strategy(self, strategy):
107110 self ._ask_and_tell = self ._ask_and_tell_based_on_loss
108111 elif strategy == "npoints" :
109112 self ._ask_and_tell = self ._ask_and_tell_based_on_npoints
113+ elif strategy == "cycle" :
114+ self ._ask_and_tell = self ._ask_and_tell_based_on_cycle
115+ self ._cycle = itertools .cycle (range (len (self .learners )))
110116 else :
111117 raise ValueError (
112- 'Only strategy="loss_improvements", strategy="loss", or '
113- ' strategy="npoints" is implemented.'
118+ 'Only strategy="loss_improvements", strategy="loss",'
119+ ' strategy="npoints", or strategy="cycle" is implemented.'
114120 )
115121
116122 def _ask_and_tell_based_on_loss_improvements (self , n ):
@@ -173,6 +179,17 @@ def _ask_and_tell_based_on_npoints(self, n):
173179 points , loss_improvements = map (list , zip (* selected ))
174180 return points , loss_improvements
175181
182+ def _ask_and_tell_based_on_cycle (self , n ):
183+ points , loss_improvements = [], []
184+ for _ in range (n ):
185+ index = next (self ._cycle )
186+ point , loss_improvement = self .learners [index ].ask (n = 1 )
187+ points .append ((index , point [0 ]))
188+ loss_improvements .append (loss_improvement [0 ])
189+ self .tell_pending ((index , point [0 ]))
190+
191+ return points , loss_improvements
192+
176193 def ask (self , n , tell_pending = True ):
177194 """Chose points for learners."""
178195 if n == 0 :
0 commit comments