Skip to content

Commit 8363aa6

Browse files
committed
Add type-hints to adaptive/learner/data_saver.py
1 parent 6b3209f commit 8363aa6

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

adaptive/learner/data_saver.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import functools
44
from collections import OrderedDict
5+
from operator import itemgetter
6+
from typing import Any
57

68
from adaptive.learner.base_learner import BaseLearner
79
from adaptive.utils import copy_docstring_from
@@ -39,7 +41,7 @@ class DataSaver:
3941
>>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
4042
"""
4143

42-
def __init__(self, learner, arg_picker):
44+
def __init__(self, learner: BaseLearner, arg_picker: itemgetter) -> None:
4345
self.learner = learner
4446
self.extra_data = OrderedDict()
4547
self.function = learner.function
@@ -49,7 +51,7 @@ def new(self) -> DataSaver:
4951
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
5052
return DataSaver(self.learner.new(), self.arg_picker)
5153

52-
def __getattr__(self, attr):
54+
def __getattr__(self, attr: str) -> Any:
5355
return getattr(self.learner, attr)
5456

5557
@copy_docstring_from(BaseLearner.tell)
@@ -122,10 +124,17 @@ def load_dataframe(
122124
key = _to_key(x[:-1])
123125
self.extra_data[key] = x[-1]
124126

125-
def _get_data(self):
127+
def _get_data(self) -> tuple[Any, OrderedDict]:
126128
return self.learner._get_data(), self.extra_data
127129

128-
def _set_data(self, data):
130+
def _set_data(
131+
self,
132+
data: (
133+
tuple[OrderedDict, OrderedDict]
134+
| tuple[dict[int | float, float], OrderedDict]
135+
| tuple[tuple[dict[int, float], int, float, float], OrderedDict]
136+
),
137+
) -> None:
129138
learner_data, self.extra_data = data
130139
self.learner._set_data(learner_data)
131140

0 commit comments

Comments
 (0)