22
33import functools
44from collections import OrderedDict
5+ from typing import Any , Callable
56
67from adaptive .learner .base_learner import BaseLearner
78from adaptive .utils import copy_docstring_from
@@ -39,7 +40,7 @@ class DataSaver:
3940 >>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
4041 """
4142
42- def __init__ (self , learner , arg_picker ) :
43+ def __init__ (self , learner : BaseLearner , arg_picker : Callable ) -> None :
4344 self .learner = learner
4445 self .extra_data = OrderedDict ()
4546 self .function = learner .function
@@ -49,21 +50,21 @@ def new(self) -> DataSaver:
4950 """Return a new `DataSaver` with the same `arg_picker` and `learner`."""
5051 return DataSaver (self .learner .new (), self .arg_picker )
5152
52- def __getattr__ (self , attr ) :
53+ def __getattr__ (self , attr : str ) -> Any :
5354 return getattr (self .learner , attr )
5455
5556 @copy_docstring_from (BaseLearner .tell )
56- def tell (self , x , result ) :
57+ def tell (self , x : Any , result : Any ) -> None :
5758 y = self .arg_picker (result )
5859 self .extra_data [x ] = result
5960 self .learner .tell (x , y )
6061
6162 @copy_docstring_from (BaseLearner .tell_pending )
62- def tell_pending (self , x ) :
63+ def tell_pending (self , x : Any ) -> None :
6364 self .learner .tell_pending (x )
6465
6566 def to_dataframe (
66- self , extra_data_name : str = "extra_data" , ** kwargs
67+ self , extra_data_name : str = "extra_data" , ** kwargs : Any
6768 ) -> pandas .DataFrame :
6869 """Return the data as a concatenated `pandas.DataFrame` from child learners.
6970
@@ -98,7 +99,7 @@ def load_dataframe(
9899 extra_data_name : str = "extra_data" ,
99100 input_names : tuple [str ] = (),
100101 ** kwargs ,
101- ):
102+ ) -> None :
102103 """Load the data from a `pandas.DataFrame` into the learner.
103104
104105 Parameters
@@ -122,33 +123,36 @@ def load_dataframe(
122123 key = _to_key (x [:- 1 ])
123124 self .extra_data [key ] = x [- 1 ]
124125
125- def _get_data (self ):
126+ def _get_data (self ) -> tuple [ Any , OrderedDict ] :
126127 return self .learner ._get_data (), self .extra_data
127128
128- def _set_data (self , data ):
129+ def _set_data (
130+ self ,
131+ data : tuple [Any , OrderedDict ],
132+ ) -> None :
129133 learner_data , self .extra_data = data
130134 self .learner ._set_data (learner_data )
131135
132- def __getstate__ (self ):
136+ def __getstate__ (self ) -> tuple [ BaseLearner , Callable , OrderedDict ] :
133137 return (
134138 self .learner ,
135139 self .arg_picker ,
136140 self .extra_data ,
137141 )
138142
139- def __setstate__ (self , state ) :
143+ def __setstate__ (self , state : tuple [ BaseLearner , Callable , OrderedDict ]) -> None :
140144 learner , arg_picker , extra_data = state
141145 self .__init__ (learner , arg_picker )
142146 self .extra_data = extra_data
143147
144148 @copy_docstring_from (BaseLearner .save )
145- def save (self , fname , compress = True ):
149+ def save (self , fname , compress = True ) -> None :
146150 # We copy this method because the 'DataSaver' is not a
147151 # subclass of the 'BaseLearner'.
148152 BaseLearner .save (self , fname , compress )
149153
150154 @copy_docstring_from (BaseLearner .load )
151- def load (self , fname , compress = True ):
155+ def load (self , fname , compress = True ) -> None :
152156 # We copy this method because the 'DataSaver' is not a
153157 # subclass of the 'BaseLearner'.
154158 BaseLearner .load (self , fname , compress )
0 commit comments