11from copy import copy
2- from functools import partial
3- from typing import Any , List , Tuple , Union
2+ from typing import Any , Callable , List , Sequence , Tuple , Union
43
54import numpy as np
65from sortedcontainers import SortedDict , SortedSet
@@ -18,22 +17,19 @@ class _IgnoreFirstArgument:
1817 pickable.
1918 """
2019
21- def __init__ (self , function : partial ) -> None :
20+ def __init__ (self , function : Callable ) -> None :
2221 self .function = function
2322
2423 def __call__ (
25- self ,
26- index_point : Union [Tuple [int , int ], Tuple [int , float ], Tuple [int , np .ndarray ]],
27- * args ,
28- ** kwargs
24+ self , index_point : Tuple [int , Union [float , np .ndarray ]], * args , ** kwargs
2925 ) -> float :
3026 index , point = index_point
3127 return self .function (point , * args , ** kwargs )
3228
33- def __getstate__ (self ) -> partial :
29+ def __getstate__ (self ) -> Callable :
3430 return self .function
3531
36- def __setstate__ (self , function : partial ) -> None :
32+ def __setstate__ (self , function : Callable ) -> None :
3733 self .__init__ (function )
3834
3935
@@ -64,7 +60,7 @@ class SequenceLearner(BaseLearner):
6460 the added benefit of having results in the local kernel already.
6561 """
6662
67- def __init__ (self , function : partial , sequence : Union [ range , np . ndarray ] ) -> None :
63+ def __init__ (self , function : Callable , sequence : Sequence ) -> None :
6864 self ._original_function = function
6965 self .function = _IgnoreFirstArgument (function )
7066 self ._to_do_indices = SortedSet ({i for i , _ in enumerate (sequence )})
@@ -73,13 +69,7 @@ def __init__(self, function: partial, sequence: Union[range, np.ndarray]) -> Non
7369 self .data = SortedDict ()
7470 self .pending_points = set ()
7571
76- def ask (
77- self , n : int , tell_pending : bool = True
78- ) -> Union [
79- Tuple [List [Tuple [int , float ]], List [float ]],
80- Tuple [List [Tuple [int , int ]], List [float ]],
81- Tuple [List [Tuple [int , np .ndarray ]], List [float ]],
82- ]:
72+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [Any , List [float ]]:
8373 indices = []
8474 points = []
8575 loss_improvements = []
@@ -119,16 +109,7 @@ def remove_unfinished(self):
119109 self ._to_do_indices .add (i )
120110 self .pending_points = set ()
121111
122- def tell (
123- self ,
124- point : Union [
125- Tuple [int , int ],
126- Tuple [int , float ],
127- Tuple [int , np .ndarray ],
128- Tuple [int , None ],
129- ],
130- value : float ,
131- ) -> None :
112+ def tell (self , point : Tuple [int , Any ], value : Any ,) -> None :
132113 index , point = point
133114 self .data [index ] = value
134115 self .pending_points .discard (index )
0 commit comments