1313from functools import partial
1414from typing import Any
1515
16+ import numpy .typing as npt
17+
1618import torch
1719from botorch .optim .utils import (
1820 _handle_numerical_errors ,
2123)
2224from botorch .optim .utils .numpy_utils import as_ndarray
2325from botorch .utils .context_managers import zero_grad_ctx
24- from numpy import float64 as np_float64 , full as np_full , ndarray , zeros as np_zeros
26+ from numpy import float64 as np_float64 , full as np_full , zeros as np_zeros
2527from torch import Tensor
2628
2729
@@ -82,10 +84,10 @@ def __init__(
8284 self ,
8385 closure : Callable [[], tuple [Tensor , Sequence [Tensor | None ]]],
8486 parameters : dict [str , Tensor ],
85- as_array : Callable [[Tensor ], ndarray ] = None , # pyre-ignore [9]
86- as_tensor : Callable [[ndarray ], Tensor ] = torch .as_tensor ,
87- get_state : Callable [[], ndarray ] = None , # pyre-ignore [9]
88- set_state : Callable [[ndarray ], None ] = None , # pyre-ignore [9]
87+ as_array : Callable [[Tensor ], npt . NDArray ] = None , # pyre-ignore [9]
88+ as_tensor : Callable [[npt . NDArray ], Tensor ] = torch .as_tensor ,
89+ get_state : Callable [[], npt . NDArray ] = None , # pyre-ignore [9]
90+ set_state : Callable [[npt . NDArray ], None ] = None , # pyre-ignore [9]
8991 fill_value : float = 0.0 ,
9092 persistent : bool = True ,
9193 ) -> None :
@@ -140,11 +142,11 @@ def __init__(
140142
141143 self .fill_value = fill_value
142144 self .persistent = persistent
143- self ._gradient_ndarray : ndarray | None = None
145+ self ._gradient_ndarray : npt . NDArray | None = None
144146
145147 def __call__ (
146- self , state : ndarray | None = None , ** kwargs : Any
147- ) -> tuple [ndarray , ndarray ]:
148+ self , state : npt . NDArray | None = None , ** kwargs : Any
149+ ) -> tuple [npt . NDArray , npt . NDArray ]:
148150 if state is not None :
149151 self .state = state
150152
@@ -164,14 +166,14 @@ def __call__(
164166 return value , grads
165167
166168 @property
167- def state (self ) -> ndarray :
169+ def state (self ) -> npt . NDArray :
168170 return self ._get_state ()
169171
170172 @state .setter
171- def state (self , state : ndarray ) -> None :
173+ def state (self , state : npt . NDArray ) -> None :
172174 self ._set_state (state )
173175
174- def _get_gradient_ndarray (self , fill_value : float | None = None ) -> ndarray :
176+ def _get_gradient_ndarray (self , fill_value : float | None = None ) -> npt . NDArray :
175177 if self .persistent and self ._gradient_ndarray is not None :
176178 if fill_value is not None :
177179 self ._gradient_ndarray .fill (fill_value )
0 commit comments