Skip to content

Commit 35c2b0c

Browse files
committed
Replace safe_zip by built-in zip function for triangular CTMRG
1 parent 066c37c commit 35c2b0c

File tree

4 files changed

+8
-4
lines changed

4 files changed

+8
-4
lines changed

varipeps/expectation/triangular_next_nearest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,13 +544,14 @@ def __call__(
544544
)
545545

546546
for sr_i, (sr_h, sr_v, sr_d, sr_np, sr_p2p, sr_2pp) in enumerate(
547-
jax.util.safe_zip(
547+
zip(
548548
step_result_horizontal,
549549
step_result_vertical,
550550
step_result_diagonal,
551551
step_result_nn_neg_pos,
552552
step_result_nn_pos_2pos,
553553
step_result_nn_2pos_pos,
554+
strict=True,
554555
)
555556
):
556557
result[sr_i] += sr_h + sr_v + sr_d + sr_np + sr_p2p + sr_2pp

varipeps/expectation/triangular_two_sites.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,11 @@ def __call__(
281281
)
282282

283283
for sr_i, (sr_h, sr_v, sr_d) in enumerate(
284-
jax.util.safe_zip(
284+
zip(
285285
step_result_horizontal,
286286
step_result_vertical,
287287
step_result_diagonal,
288+
strict=True,
288289
)
289290
):
290291
result[sr_i] += sr_h + sr_v + sr_d

varipeps/mapping/kagome.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2007,11 +2007,12 @@ def __call__(
20072007
)
20082008

20092009
for sr_i, (sr_o, sr_h, sr_v, sr_d) in enumerate(
2010-
jax.util.safe_zip(
2010+
zip(
20112011
step_result_onsite[: len(self.up_nearest_gates)],
20122012
step_result_horizontal[: len(self.up_nearest_gates)],
20132013
step_result_vertical[: len(self.up_nearest_gates)],
20142014
step_result_diagonal[: len(self.up_nearest_gates)],
2015+
strict=True,
20152016
)
20162017
):
20172018
result[sr_i] += sr_o + sr_h + sr_v + sr_d

varipeps/mapping/maple_leaf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2238,11 +2238,12 @@ def __call__(
22382238
)
22392239

22402240
for sr_i, (sr_o, sr_h, sr_v, sr_d) in enumerate(
2241-
jax.util.safe_zip(
2241+
zip(
22422242
step_result_onsite[: len(self.green_gates)],
22432243
step_result_horizontal[: len(self.green_gates)],
22442244
step_result_vertical[: len(self.green_gates)],
22452245
step_result_diagonal[: len(self.green_gates)],
2246+
strict=True,
22462247
)
22472248
):
22482249
result[sr_i] += sr_o + sr_h + sr_v + sr_d

0 commit comments

Comments
 (0)