Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit 95861b2

Browse files
authored
fix error check in jax eigs (#909)
* fix error check in jax eigs * fix docstring of krylov.eigs and eigsh_lanczos * fix bug, update docstrings * fix tests * fix test
1 parent 523e7a8 commit 95861b2

File tree

4 files changed

+69
-33
lines changed

4 files changed

+69
-33
lines changed

tensornetwork/backends/jax/jax_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def A(H,x):
281281
A: A (sparse) implementation of a linear operator.
282282
Call signature of `A` is `res = A(vector, *args)`, where `vector`
283283
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
284-
arsg: A list of arguments to `A`. `A` will be called as
284+
args: A list of arguments to `A`. `A` will be called as
285285
`res = A(initial_state, *args)`.
286286
initial_state: An initial vector for the algorithm. If `None`,
287287
a random initial `Tensor` is created using the `backend.randn` method

tensornetwork/backends/jax/jax_backend_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def mv(x, H):
748748
def test_eigs_eigsh_large_ncv_with_init(dtype, solver, matrix_generator,
749749
exact_decomp, which):
750750
backend = jax_backend.JaxBackend()
751-
D = 16
751+
D = 100
752752
np.random.seed(10)
753753
init = backend.randn((D,), dtype=dtype, seed=10)
754754
H = matrix_generator(backend, dtype, D)
@@ -949,7 +949,8 @@ def test_eigs_eigsh_raises(solver, whichs):
949949
def test_eigs_dtype_raises():
950950
solver = jax_backend.JaxBackend().eigs
951951
with pytest.raises(TypeError, match="dtype"):
952-
solver(lambda x: x, shape=(10,), dtype=np.int32)
952+
solver(lambda x: x, shape=(10,), dtype=np.int32,
953+
num_krylov_vecs=10)
953954

954955
##################################################################
955956
############# This test should just not crash ################

tensornetwork/backends/jax/jitted_functions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -743,10 +743,11 @@ def implicitly_restarted_arnoldi_method(
743743

744744
dim = np.prod(shape).astype(np.int32)
745745
num_expand = num_krylov_vecs - numeig
746-
if num_krylov_vecs <= numeig < dim:
747-
raise ValueError(f"num_krylov_vecs must be between numeig <"
748-
f" num_krylov_vecs <= dim = {dim},"
749-
f" num_krylov_vecs = {num_krylov_vecs}")
746+
if not numeig <= num_krylov_vecs <= dim:
747+
raise ValueError(f"num_krylov_vecs must be between numeig <="
748+
f" num_krylov_vecs <= dim, got "
749+
f" numeig = {numeig}, num_krylov_vecs = "
750+
f"{num_krylov_vecs}, dim = {dim}.")
750751
if numeig > dim:
751752
raise ValueError(f"number of requested eigenvalues numeig = {numeig} "
752753
f"is larger than the dimension of the operator "

tensornetwork/linalg/krylov.py

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,13 @@ def eigsh_lanczos(A: Callable,
125125
"""
126126
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
127127
of `A`.
128+
128129
Args:
129130
A: A (sparse) implementation of a linear operator.
130131
Call signature of `A` is `res = A(vector, *args)`, where `vector`
131132
can be an arbitrary `Array`, and `res.shape` has to be `vector.shape`.
132-
arsg: A list of arguments to `A`. `A` will be called as
133+
backend: A backend, text specifying one, or None.
134+
args: A list of arguments to `A`. `A` will be called as
133135
`res = A(x0, *args)`.
134136
x0: An initial vector for the Lanczos algorithm. If `None`,
135137
a random initial vector is created using the `backend.randn` method
@@ -152,6 +154,7 @@ def eigsh_lanczos(A: Callable,
152154
iterations to check convergence.
153155
reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
154156
explicit orthogonalization (more costly than `reorthogonalize=False`)
157+
155158
Returns:
156159
(eigvals, eigvecs)
157160
eigvals: A list of `numeig` lowest eigenvalues
@@ -182,39 +185,70 @@ def eigs(A: Callable,
182185
which: Text = 'LR',
183186
maxiter: int = 20) -> Tuple[Tensor, List]:
184187
"""
185-
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
186-
of `A`.
188+
Implicitly restarted Arnoldi method for finding the lowest
189+
eigenvector-eigenvalue pairs of a linear operator `A`.
190+
`A` is a function implementing the matrix-vector
191+
product.
192+
193+
WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
194+
at the first invocation of `eigs`, and on any subsequent calls
195+
if the python `id` of `A` changes, even if the formal definition of `A`
196+
stays the same.
197+
Example: the following will jit once at the beginning, and then never again:
198+
199+
```python
200+
import jax
201+
import numpy as np
202+
def A(H,x):
203+
return jax.np.dot(H,x)
204+
for n in range(100):
205+
H = jax.np.array(np.random.rand(10,10))
206+
x = jax.np.array(np.random.rand(10,10))
207+
res = eigs(A, [H],x) #jitting is triggerd only at `n=0`
208+
```
209+
210+
The following code triggers jitting at every iteration, which
211+
results in considerably reduced performance
212+
213+
```python
214+
import jax
215+
import numpy as np
216+
for n in range(100):
217+
def A(H,x):
218+
return jax.np.dot(H,x)
219+
H = jax.np.array(np.random.rand(10,10))
220+
x = jax.np.array(np.random.rand(10,10))
221+
res = eigs(A, [H],x) #jitting is triggerd at every step `n`
222+
```
223+
187224
Args:
188225
A: A (sparse) implementation of a linear operator.
189226
Call signature of `A` is `res = A(vector, *args)`, where `vector`
190-
can be an arbitrary `Array`, and `res.shape` has to be `vector.shape`.
191-
arsg: A list of arguments to `A`. `A` will be called as
192-
`res = A(x0, *args)`.
193-
x0: An initial vector for the Lanczos algorithm. If `None`,
194-
a random initial vector is created using the `backend.randn` method
227+
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
228+
backend: A backend, text specifying one, or None.
229+
args: A list of arguments to `A`. `A` will be called as
230+
`res = A(initial_state, *args)`.
231+
x0: An initial vector for the algorithm. If `None`,
232+
a random initial `Tensor` is created using the `backend.randn` method
195233
shape: The shape of the input-dimension of `A`.
196-
dtype: The dtype of the input `A`. If both no `x0` is provided,
234+
dtype: The dtype of the input `A`. If no `initial_state` is provided,
197235
a random initial state with shape `shape` and dtype `dtype` is created.
198236
num_krylov_vecs: The number of iterations (number of krylov vectors).
199-
numeig: The nummber of eigenvector-eigenvalue pairs to be computed.
200-
If `numeig > 1`, `reorthogonalize` has to be `True`.
201-
tol: The desired precision of the eigenvalus. Uses
202-
`backend.norm(eigvalsnew[0:numeig] - eigvalsold[0:numeig]) < tol`
203-
as stopping criterion between two diagonalization steps of the
204-
tridiagonal operator.
205-
delta: Stopping criterion for Lanczos iteration.
206-
If a Krylov vector :math: `x_n` has an L2 norm
207-
:math:`\\lVert x_n\\rVert < delta`, the iteration
208-
is stopped. It means that an (approximate) invariant subspace has
209-
been found.
210-
ndiag: The tridiagonal Operator is diagonalized every `ndiag`
211-
iterations to check convergence.
212-
reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
213-
explicit orthogonalization (more costly than `reorthogonalize=False`)
237+
numeig: The number of eigenvector-eigenvalue pairs to be computed.
238+
tol: The desired precision of the eigenvalues. For the jax backend
239+
this has currently no effect, and precision of eigenvalues is not
240+
guaranteed. This feature may be added at a later point. To increase
241+
precision the caller can either increase `maxiter` or `num_krylov_vecs`.
242+
which: Flag for targetting different types of eigenvalues. Currently
243+
supported are `which = 'LR'` (larges real part) and `which = 'LM'`
244+
(larges magnitude).
245+
maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
246+
equivalent to a simple Arnoldi method.
247+
214248
Returns:
215249
(eigvals, eigvecs)
216-
eigvals: A list of `numeig` lowest eigenvalues
217-
eigvecs: A list of `numeig` lowest eigenvectors
250+
eigvals: A list of `numeig` eigenvalues
251+
eigvecs: A list of `numeig` eigenvectors
218252
"""
219253
backend, x0_array, args_array = krylov_error_checks(backend, x0, args)
220254
mv = KRYLOV_MATVEC_CACHE.retrieve(backend.name, A)

0 commit comments

Comments
 (0)