11import sys
2- import warnings
32from copy import copy
43
4+ from sortedcontainers import SortedSet
5+
56from adaptive .learner .base_learner import BaseLearner
67
78inf = sys .float_info .max
89
910
10- def ensure_hashable (x ):
11- try :
12- hash (x )
13- return x
14- except TypeError :
15- msg = "The items in `sequence` need to be hashable, {}. Make sure you reflect this in your function."
16- if isinstance (x , dict ):
17- warnings .warn (msg .format ("we converted `dict` to `tuple(dict.items())`" ))
18- return tuple (x .items ())
19- else :
20- warnings .warn (msg .format ("we tried to cast the items to a tuple" ))
21- return tuple (x )
11+ class _IgnoreFirstArgument :
12+ """Remove the first argument from the call signature.
2213
14+ The SequenceLearner's function receives a tuple ``(index, point)``
15+ but the original function only takes ``point``.
2316
24- class SequenceLearner (BaseLearner ):
25- def __init__ (self , function , sequence ):
17+ This is the same as `lambda x: function(x[1])`, however, that is not
18+ pickable.
19+ """
20+
21+ def __init__ (self , function ):
2622 self .function = function
2723
28- # We use a poor man's OrderedSet, a dict that points to None.
29- self ._to_do_seq = {ensure_hashable (x ): None for x in sequence }
24+ def __call__ (self , index_point , * args , ** kwargs ):
25+ index , point = index_point
26+ return self .function (point , * args , ** kwargs )
27+
28+ def __getstate__ (self ):
29+ return self .function
30+
31+ def __setstate__ (self , function ):
32+ self .__init__ (function )
33+
34+
35+ class SequenceLearner (BaseLearner ):
36+ def __init__ (self , function , sequence ):
37+ self ._original_function = function
38+ self .function = _IgnoreFirstArgument (function )
39+ self ._to_do_indices = SortedSet ({i for i , _ in enumerate (sequence )})
3040 self ._ntotal = len (sequence )
3141 self .sequence = copy (sequence )
3242 self .data = {}
3343 self .pending_points = set ()
3444
3545 def ask (self , n , tell_pending = True ):
46+ indices = []
3647 points = []
3748 loss_improvements = []
38- for point in self ._to_do_seq :
49+ for index in self ._to_do_indices :
3950 if len (points ) >= n :
4051 break
41- points .append (point )
52+ point = self .sequence [index ]
53+ indices .append (index )
54+ points .append ((index , point ))
4255 loss_improvements .append (1 / self ._ntotal )
4356
4457 if tell_pending :
45- for p in points :
46- self .tell_pending (p )
58+ for i , p in zip ( indices , points ) :
59+ self .tell_pending (( i , p ) )
4760
4861 return points , loss_improvements
4962
@@ -55,34 +68,36 @@ def _set_data(self, data):
5568 self .tell_many (* zip (* data .items ()))
5669
5770 def loss (self , real = True ):
58- if not (self ._to_do_seq or self .pending_points ):
71+ if not (self ._to_do_indices or self .pending_points ):
5972 return 0
6073 else :
6174 npoints = self .npoints + (0 if real else len (self .pending_points ))
6275 return (self ._ntotal - npoints ) / self ._ntotal
6376
6477 def remove_unfinished (self ):
65- for p in self .pending_points :
66- self ._to_do_seq [ p ] = None
78+ for i in self .pending_points :
79+ self ._to_do_indices . add ( i )
6780 self .pending_points = set ()
6881
6982 def tell (self , point , value ):
70- self .data [point ] = value
71- self .pending_points .discard (point )
72- self ._to_do_seq .pop (point , None )
83+ index , point = point
84+ self .data [index ] = value
85+ self .pending_points .discard (index )
86+ self ._to_do_indices .discard (index )
7387
7488 def tell_pending (self , point ):
75- self .pending_points .add (point )
76- self ._to_do_seq .pop (point , None )
89+ index , point = point
90+ self .pending_points .add (index )
91+ self ._to_do_indices .discard (index )
7792
7893 def done (self ):
79- return not self ._to_do_seq and not self .pending_points
94+ return not self ._to_do_indices and not self .pending_points
8095
8196 def result (self ):
8297 """Get back the data in the same order as ``sequence``."""
8398 if not self .done ():
8499 raise Exception ("Learner is not yet complete." )
85- return [self .data [ensure_hashable ( x ) ] for x in self .sequence ]
100+ return [self .data [i ] for i , _ in enumerate ( self .sequence ) ]
86101
87102 @property
88103 def npoints (self ):
0 commit comments