Skip to content

Commit 9490093

Browse files
committed
Add type-hints to adaptive/learner/balancing_learner.py
1 parent 1786f80 commit 9490093

File tree

1 file changed

+77
-28
lines changed

1 file changed

+77
-28
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,46 @@
11
from __future__ import annotations
22

33
import itertools
4+
import numbers
45
from collections import defaultdict
56
from collections.abc import Iterable
67
from contextlib import suppress
78
from functools import partial
89
from operator import itemgetter
10+
from typing import Any, Callable, Dict, Sequence, Tuple, Union
911

1012
import numpy as np
1113

1214
from adaptive.learner.base_learner import BaseLearner
1315
from adaptive.notebook_integration import ensure_holoviews
1416
from adaptive.utils import cache_latest, named_product, restore
1517

18+
try:
19+
from typing import Literal, TypeAlias
20+
except ImportError:
21+
from typing_extensions import Literal, TypeAlias
22+
1623
try:
1724
import pandas
1825

1926
with_pandas = True
20-
2127
except ModuleNotFoundError:
2228
with_pandas = False
2329

2430

25-
def dispatch(child_functions, arg):
31+
def dispatch(child_functions: list[Callable], arg: Any) -> Any:
2632
index, x = arg
2733
return child_functions[index](x)
2834

2935

36+
STRATEGY_TYPE: TypeAlias = Literal["loss_improvements", "loss", "npoints", "cycle"]
37+
38+
CDIMS_TYPE: TypeAlias = Union[
39+
Sequence[Dict[str, Any]],
40+
Tuple[Sequence[str], Sequence[Tuple[Any, ...]]],
41+
]
42+
43+
3044
class BalancingLearner(BaseLearner):
3145
r"""Choose the optimal points from a set of learners.
3246
@@ -78,13 +92,19 @@ class BalancingLearner(BaseLearner):
7892
behave in an undefined way. Change the `strategy` in that case.
7993
"""
8094

81-
def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
95+
def __init__(
96+
self,
97+
learners: list[BaseLearner],
98+
*,
99+
cdims: CDIMS_TYPE | None = None,
100+
strategy: STRATEGY_TYPE = "loss_improvements",
101+
) -> None:
82102
self.learners = learners
83103

84104
# Naively we would make 'function' a method, but this causes problems
85105
# when using executors from 'concurrent.futures' because we have to
86106
# pickle the whole learner.
87-
self.function = partial(dispatch, [l.function for l in self.learners])
107+
self.function = partial(dispatch, [l.function for l in self.learners]) # type: ignore
88108

89109
self._ask_cache = {}
90110
self._loss = {}
@@ -96,7 +116,7 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
96116
"A BalacingLearner can handle only one type" " of learners."
97117
)
98118

99-
self.strategy = strategy
119+
self.strategy: STRATEGY_TYPE = strategy
100120

101121
def new(self) -> BalancingLearner:
102122
"""Create a new `BalancingLearner` with the same parameters."""
@@ -107,21 +127,21 @@ def new(self) -> BalancingLearner:
107127
)
108128

109129
@property
110-
def data(self):
130+
def data(self) -> dict[tuple[int, Any], Any]:
111131
data = {}
112132
for i, l in enumerate(self.learners):
113133
data.update({(i, p): v for p, v in l.data.items()})
114134
return data
115135

116136
@property
117-
def pending_points(self):
137+
def pending_points(self) -> set[tuple[int, Any]]:
118138
pending_points = set()
119139
for i, l in enumerate(self.learners):
120140
pending_points.update({(i, p) for p in l.pending_points})
121141
return pending_points
122142

123143
@property
124-
def npoints(self):
144+
def npoints(self) -> int:
125145
return sum(l.npoints for l in self.learners)
126146

127147
@property
@@ -134,7 +154,7 @@ def nsamples(self):
134154
)
135155

136156
@property
137-
def strategy(self):
157+
def strategy(self) -> STRATEGY_TYPE:
138158
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
139159
'cycle'. The points that the `BalancingLearner` choses can be either
140160
based on: the best 'loss_improvements', the smallest total 'loss' of
@@ -145,7 +165,7 @@ def strategy(self):
145165
return self._strategy
146166

147167
@strategy.setter
148-
def strategy(self, strategy):
168+
def strategy(self, strategy: STRATEGY_TYPE) -> None:
149169
self._strategy = strategy
150170
if strategy == "loss_improvements":
151171
self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements
@@ -162,7 +182,9 @@ def strategy(self, strategy):
162182
' strategy="npoints", or strategy="cycle" is implemented.'
163183
)
164184

165-
def _ask_and_tell_based_on_loss_improvements(self, n):
185+
def _ask_and_tell_based_on_loss_improvements(
186+
self, n: int
187+
) -> tuple[list[tuple[int, Any]], list[float]]:
166188
selected = [] # tuples ((learner_index, point), loss_improvement)
167189
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
168190
for _ in range(n):
@@ -185,7 +207,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
185207
points, loss_improvements = map(list, zip(*selected))
186208
return points, loss_improvements
187209

188-
def _ask_and_tell_based_on_loss(self, n):
210+
def _ask_and_tell_based_on_loss(
211+
self, n: int
212+
) -> tuple[list[tuple[int, Any]], list[float]]:
189213
selected = [] # tuples ((learner_index, point), loss_improvement)
190214
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
191215
for _ in range(n):
@@ -206,7 +230,9 @@ def _ask_and_tell_based_on_loss(self, n):
206230
points, loss_improvements = map(list, zip(*selected))
207231
return points, loss_improvements
208232

209-
def _ask_and_tell_based_on_npoints(self, n):
233+
def _ask_and_tell_based_on_npoints(
234+
self, n: numbers.Integral
235+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
210236
selected = [] # tuples ((learner_index, point), loss_improvement)
211237
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
212238
for _ in range(n):
@@ -222,7 +248,9 @@ def _ask_and_tell_based_on_npoints(self, n):
222248
points, loss_improvements = map(list, zip(*selected))
223249
return points, loss_improvements
224250

225-
def _ask_and_tell_based_on_cycle(self, n):
251+
def _ask_and_tell_based_on_cycle(
252+
self, n: int
253+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
226254
points, loss_improvements = [], []
227255
for _ in range(n):
228256
index = next(self._cycle)
@@ -233,7 +261,9 @@ def _ask_and_tell_based_on_cycle(self, n):
233261

234262
return points, loss_improvements
235263

236-
def ask(self, n, tell_pending=True):
264+
def ask(
265+
self, n: int, tell_pending: bool = True
266+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
237267
"""Chose points for learners."""
238268
if n == 0:
239269
return [], []
@@ -244,20 +274,20 @@ def ask(self, n, tell_pending=True):
244274
else:
245275
return self._ask_and_tell(n)
246276

247-
def tell(self, x, y):
277+
def tell(self, x: tuple[numbers.Integral, Any], y: Any) -> None:
248278
index, x = x
249279
self._ask_cache.pop(index, None)
250280
self._loss.pop(index, None)
251281
self._pending_loss.pop(index, None)
252282
self.learners[index].tell(x, y)
253283

254-
def tell_pending(self, x):
284+
def tell_pending(self, x: tuple[numbers.Integral, Any]) -> None:
255285
index, x = x
256286
self._ask_cache.pop(index, None)
257287
self._loss.pop(index, None)
258288
self.learners[index].tell_pending(x)
259289

260-
def _losses(self, real=True):
290+
def _losses(self, real: bool = True) -> list[float]:
261291
losses = []
262292
loss_dict = self._loss if real else self._pending_loss
263293

@@ -269,11 +299,16 @@ def _losses(self, real=True):
269299
return losses
270300

271301
@cache_latest
272-
def loss(self, real=True):
302+
def loss(self, real: bool = True) -> float:
273303
losses = self._losses(real)
274304
return max(losses)
275305

276-
def plot(self, cdims=None, plotter=None, dynamic=True):
306+
def plot(
307+
self,
308+
cdims: CDIMS_TYPE | None = None,
309+
plotter: Callable[[BaseLearner], Any] | None = None,
310+
dynamic: bool = True,
311+
):
277312
"""Returns a DynamicMap with sliders.
278313
279314
Parameters
@@ -346,13 +381,19 @@ def plot_function(*args):
346381
vals = {d.name: d.values for d in dm.dimensions() if d.values}
347382
return hv.HoloMap(dm.select(**vals))
348383

349-
def remove_unfinished(self):
384+
def remove_unfinished(self) -> None:
350385
"""Remove uncomputed data from the learners."""
351386
for learner in self.learners:
352387
learner.remove_unfinished()
353388

354389
@classmethod
355-
def from_product(cls, f, learner_type, learner_kwargs, combos):
390+
def from_product(
391+
cls,
392+
f,
393+
learner_type: BaseLearner,
394+
learner_kwargs: dict[str, Any],
395+
combos: dict[str, Sequence[Any]],
396+
) -> BalancingLearner:
356397
"""Create a `BalancingLearner` with learners of all combinations of
357398
named variables’ values. The `cdims` will be set correctly, so calling
358399
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -448,7 +489,11 @@ def load_dataframe(
448489
for i, gr in df.groupby(index_name):
449490
self.learners[i].load_dataframe(gr, **kwargs)
450491

451-
def save(self, fname, compress=True):
492+
def save(
493+
self,
494+
fname: Callable[[BaseLearner], str] | Sequence[str],
495+
compress: bool = True,
496+
) -> None:
452497
"""Save the data of the child learners into pickle files
453498
in a directory.
454499
@@ -486,7 +531,11 @@ def save(self, fname, compress=True):
486531
for l in self.learners:
487532
l.save(fname(l), compress=compress)
488533

489-
def load(self, fname, compress=True):
534+
def load(
535+
self,
536+
fname: Callable[[BaseLearner], str] | Sequence[str],
537+
compress: bool = True,
538+
) -> None:
490539
"""Load the data of the child learners from pickle files
491540
in a directory.
492541
@@ -510,20 +559,20 @@ def load(self, fname, compress=True):
510559
for l in self.learners:
511560
l.load(fname(l), compress=compress)
512561

513-
def _get_data(self):
562+
def _get_data(self) -> list[Any]:
514563
return [l._get_data() for l in self.learners]
515564

516-
def _set_data(self, data):
565+
def _set_data(self, data: list[Any]):
517566
for l, _data in zip(self.learners, data):
518567
l._set_data(_data)
519568

520-
def __getstate__(self):
569+
def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
521570
return (
522571
self.learners,
523572
self._cdims_default,
524573
self.strategy,
525574
)
526575

527-
def __setstate__(self, state):
576+
def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
528577
learners, cdims, strategy = state
529578
self.__init__(learners, cdims=cdims, strategy=strategy)

0 commit comments

Comments
 (0)