|
1 | 1 | import jax.numpy as jnp |
2 | 2 | from jax import value_and_grad |
3 | | -from jax.util import safe_zip |
4 | 3 |
|
5 | 4 | from varipeps import varipeps_config |
6 | 5 | from varipeps.peps import PEPS_Unit_Cell |
@@ -39,7 +38,7 @@ def _map_tensors( |
39 | 38 | old_tensors = unitcell.get_unique_tensors() |
40 | 39 | if not all( |
41 | 40 | jnp.allclose(ti, tj_obj.tensor) |
42 | | - for ti, tj_obj in safe_zip(peps_tensors, old_tensors) |
| 41 | + for ti, tj_obj in zip(peps_tensors, old_tensors, strict=True) |
43 | 42 | ): |
44 | 43 | raise ValueError( |
45 | 44 | "Input tensors and provided unitcell are not the same state." |
@@ -110,14 +109,14 @@ def calc_ctmrg_expectation( |
110 | 109 | spiral_vectors_x = (spiral_vectors_x,) |
111 | 110 | spiral_vectors = tuple( |
112 | 111 | jnp.array((sx, sy)) |
113 | | - for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors) |
| 112 | + for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True) |
114 | 113 | ) |
115 | 114 | elif spiral_vectors_y is not None: |
116 | 115 | if isinstance(spiral_vectors_y, jnp.ndarray): |
117 | 116 | spiral_vectors_y = (spiral_vectors_y,) |
118 | 117 | spiral_vectors = tuple( |
119 | 118 | jnp.array((sx, sy)) |
120 | | - for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y) |
| 119 | + for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True) |
121 | 120 | ) |
122 | 121 | else: |
123 | 122 | peps_tensors, unitcell = _map_tensors( |
@@ -211,14 +210,14 @@ def calc_preconverged_ctmrg_value_and_grad( |
211 | 210 | spiral_vectors_x = (spiral_vectors_x,) |
212 | 211 | spiral_vectors = tuple( |
213 | 212 | jnp.array((sx, sy)) |
214 | | - for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors) |
| 213 | + for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True) |
215 | 214 | ) |
216 | 215 | elif spiral_vectors_y is not None: |
217 | 216 | if isinstance(spiral_vectors_y, jnp.ndarray): |
218 | 217 | spiral_vectors_y = (spiral_vectors_y,) |
219 | 218 | spiral_vectors = tuple( |
220 | 219 | jnp.array((sx, sy)) |
221 | | - for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y) |
| 220 | + for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True) |
222 | 221 | ) |
223 | 222 | else: |
224 | 223 | peps_tensors, unitcell = _map_tensors( |
@@ -293,14 +292,14 @@ def calc_ctmrg_expectation_custom( |
293 | 292 | spiral_vectors_x = (spiral_vectors_x,) |
294 | 293 | spiral_vectors = tuple( |
295 | 294 | jnp.array((sx, sy)) |
296 | | - for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors) |
| 295 | + for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True) |
297 | 296 | ) |
298 | 297 | elif spiral_vectors_y is not None: |
299 | 298 | if isinstance(spiral_vectors_y, jnp.ndarray): |
300 | 299 | spiral_vectors_y = (spiral_vectors_y,) |
301 | 300 | spiral_vectors = tuple( |
302 | 301 | jnp.array((sx, sy)) |
303 | | - for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y) |
| 302 | + for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True) |
304 | 303 | ) |
305 | 304 | else: |
306 | 305 | peps_tensors, unitcell = _map_tensors( |
|
0 commit comments