22
33import functools
44from collections import OrderedDict
5+ from operator import itemgetter
6+ from typing import Any
57
68from adaptive .learner .base_learner import BaseLearner
79from adaptive .utils import copy_docstring_from
@@ -39,7 +41,7 @@ class DataSaver:
3941 >>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
4042 """
4143
42- def __init__ (self , learner , arg_picker ) :
44+ def __init__ (self , learner : BaseLearner , arg_picker : itemgetter ) -> None :
4345 self .learner = learner
4446 self .extra_data = OrderedDict ()
4547 self .function = learner .function
@@ -49,7 +51,7 @@ def new(self) -> DataSaver:
4951 """Return a new `DataSaver` with the same `arg_picker` and `learner`."""
5052 return DataSaver (self .learner .new (), self .arg_picker )
5153
52- def __getattr__ (self , attr ) :
54+ def __getattr__ (self , attr : str ) -> Any :
5355 return getattr (self .learner , attr )
5456
5557 @copy_docstring_from (BaseLearner .tell )
@@ -122,10 +124,17 @@ def load_dataframe(
122124 key = _to_key (x [:- 1 ])
123125 self .extra_data [key ] = x [- 1 ]
124126
125- def _get_data (self ):
127+ def _get_data (self ) -> tuple [ Any , OrderedDict ] :
126128 return self .learner ._get_data (), self .extra_data
127129
128- def _set_data (self , data ):
130+ def _set_data (
131+ self ,
132+ data : (
133+ tuple [OrderedDict , OrderedDict ]
134+ | tuple [dict [int | float , float ], OrderedDict ]
135+ | tuple [tuple [dict [int , float ], int , float , float ], OrderedDict ]
136+ ),
137+ ) -> None :
129138 learner_data , self .extra_data = data
130139 self .learner ._set_data (learner_data )
131140
0 commit comments