@@ -125,11 +125,13 @@ def eigsh_lanczos(A: Callable,
125125 """
126126 Lanczos method for finding the lowest eigenvector-eigenvalue pairs
127127 of `A`.
128+
128129 Args:
129130 A: A (sparse) implementation of a linear operator.
130131 Call signature of `A` is `res = A(vector, *args)`, where `vector`
131132 can be an arbitrary `Array`, and `res.shape` has to be `vector.shape`.
132- arsg: A list of arguments to `A`. `A` will be called as
133+ backend: A backend, text specifying one, or None.
134+ args: A list of arguments to `A`. `A` will be called as
133135 `res = A(x0, *args)`.
134136 x0: An initial vector for the Lanczos algorithm. If `None`,
135137 a random initial vector is created using the `backend.randn` method
@@ -152,6 +154,7 @@ def eigsh_lanczos(A: Callable,
152154 iterations to check convergence.
153155 reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
154156 explicit orthogonalization (more costly than `reorthogonalize=False`)
157+
155158 Returns:
156159 (eigvals, eigvecs)
157160 eigvals: A list of `numeig` lowest eigenvalues
@@ -182,39 +185,70 @@ def eigs(A: Callable,
182185 which : Text = 'LR' ,
183186 maxiter : int = 20 ) -> Tuple [Tensor , List ]:
184187 """
185- Lanczos method for finding the lowest eigenvector-eigenvalue pairs
186- of `A`.
188+ Implicitly restarted Arnoldi method for finding the lowest
189+ eigenvector-eigenvalue pairs of a linear operator `A`.
190+ `A` is a function implementing the matrix-vector
191+ product.
192+
193+ WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
194+ at the first invocation of `eigs`, and on any subsequent calls
195+ if the python `id` of `A` changes, even if the formal definition of `A`
196+ stays the same.
197+ Example: the following will jit once at the beginning, and then never again:
198+
199+ ```python
200+ import jax
201+ import numpy as np
202+ def A(H,x):
203+ return jax.np.dot(H,x)
204+ for n in range(100):
205+ H = jax.np.array(np.random.rand(10,10))
206+ x = jax.np.array(np.random.rand(10,10))
207+ res = eigs(A, [H],x) #jitting is triggerd only at `n=0`
208+ ```
209+
210+ The following code triggers jitting at every iteration, which
211+ results in considerably reduced performance
212+
213+ ```python
214+ import jax
215+ import numpy as np
216+ for n in range(100):
217+ def A(H,x):
218+ return jax.np.dot(H,x)
219+ H = jax.np.array(np.random.rand(10,10))
220+ x = jax.np.array(np.random.rand(10,10))
221+ res = eigs(A, [H],x) #jitting is triggerd at every step `n`
222+ ```
223+
187224 Args:
188225 A: A (sparse) implementation of a linear operator.
189226 Call signature of `A` is `res = A(vector, *args)`, where `vector`
190- can be an arbitrary `Array`, and `res.shape` has to be `vector.shape`.
191- arsg: A list of arguments to `A`. `A` will be called as
192- `res = A(x0, *args)`.
193- x0: An initial vector for the Lanczos algorithm. If `None`,
194- a random initial vector is created using the `backend.randn` method
227+ can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
228+ backend: A backend, text specifying one, or None.
229+ args: A list of arguments to `A`. `A` will be called as
230+ `res = A(initial_state, *args)`.
231+ x0: An initial vector for the algorithm. If `None`,
232+ a random initial `Tensor` is created using the `backend.randn` method
195233 shape: The shape of the input-dimension of `A`.
196- dtype: The dtype of the input `A`. If both no `x0 ` is provided,
234+ dtype: The dtype of the input `A`. If no `initial_state ` is provided,
197235 a random initial state with shape `shape` and dtype `dtype` is created.
198236 num_krylov_vecs: The number of iterations (number of krylov vectors).
199- numeig: The nummber of eigenvector-eigenvalue pairs to be computed.
200- If `numeig > 1`, `reorthogonalize` has to be `True`.
201- tol: The desired precision of the eigenvalus. Uses
202- `backend.norm(eigvalsnew[0:numeig] - eigvalsold[0:numeig]) < tol`
203- as stopping criterion between two diagonalization steps of the
204- tridiagonal operator.
205- delta: Stopping criterion for Lanczos iteration.
206- If a Krylov vector :math: `x_n` has an L2 norm
207- :math:`\\ lVert x_n\\ rVert < delta`, the iteration
208- is stopped. It means that an (approximate) invariant subspace has
209- been found.
210- ndiag: The tridiagonal Operator is diagonalized every `ndiag`
211- iterations to check convergence.
212- reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
213- explicit orthogonalization (more costly than `reorthogonalize=False`)
237+ numeig: The number of eigenvector-eigenvalue pairs to be computed.
238+ tol: The desired precision of the eigenvalues. For the jax backend
239+ this has currently no effect, and precision of eigenvalues is not
240+ guaranteed. This feature may be added at a later point. To increase
241+ precision the caller can either increase `maxiter` or `num_krylov_vecs`.
242+ which: Flag for targetting different types of eigenvalues. Currently
243+ supported are `which = 'LR'` (larges real part) and `which = 'LM'`
244+ (larges magnitude).
245+ maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
246+ equivalent to a simple Arnoldi method.
247+
214248 Returns:
215249 (eigvals, eigvecs)
216- eigvals: A list of `numeig` lowest eigenvalues
217- eigvecs: A list of `numeig` lowest eigenvectors
250+ eigvals: A list of `numeig` eigenvalues
251+ eigvecs: A list of `numeig` eigenvectors
218252 """
219253 backend , x0_array , args_array = krylov_error_checks (backend , x0 , args )
220254 mv = KRYLOV_MATVEC_CACHE .retrieve (backend .name , A )
0 commit comments