11from __future__ import annotations
22
33import collections
4+ from typing import Callable
45
56import numpy as np
67from skopt import Optimizer
@@ -25,8 +26,8 @@ class SKOptLearner(Optimizer, BaseLearner):
2526 Arguments to pass to ``skopt.Optimizer``.
2627 """
2728
28- def __init__ (self , function , ** kwargs ):
29- self .function = function
29+ def __init__ (self , function : Callable , ** kwargs ) -> None :
30+ self .function = function # type: ignore
3031 self .pending_points = set ()
3132 self .data = collections .OrderedDict ()
3233 self ._kwargs = kwargs
@@ -36,7 +37,7 @@ def new(self) -> SKOptLearner:
3637 """Return a new `~adaptive.SKOptLearner` without the data."""
3738 return SKOptLearner (self .function , ** self ._kwargs )
3839
39- def tell (self , x , y , fit = True ):
40+ def tell (self , x : float | list [ float ] , y : float , fit : bool = True ) -> None :
4041 if isinstance (x , collections .abc .Iterable ):
4142 self .pending_points .discard (tuple (x ))
4243 self .data [tuple (x )] = y
@@ -55,7 +56,7 @@ def remove_unfinished(self):
5556 pass
5657
5758 @cache_latest
58- def loss (self , real = True ):
59+ def loss (self , real : bool = True ) -> float :
5960 if not self .models :
6061 return np .inf
6162 else :
@@ -65,7 +66,12 @@ def loss(self, real=True):
6566 # estimator of loss, but it is the cheapest.
6667 return 1 - model .score (self .Xi , self .yi )
6768
68- def ask (self , n , tell_pending = True ):
69+ def ask (
70+ self , n : int , tell_pending : bool = True
71+ ) -> (
72+ tuple [list [float ], list [float ]]
73+ | tuple [list [list [float ]], list [float ]] # XXX: this indicates a bug!
74+ ):
6975 if not tell_pending :
7076 raise NotImplementedError (
7177 "Asking points is an irreversible "
@@ -79,7 +85,7 @@ def ask(self, n, tell_pending=True):
7985 return [p [0 ] for p in points ], [self .loss () / n ] * n
8086
8187 @property
82- def npoints (self ):
88+ def npoints (self ) -> int :
8389 """Number of evaluated points."""
8490 return len (self .Xi )
8591
0 commit comments