Skip to content

Commit 94d5aa7

Browse files
committed
Implement triangular CTMRG for maple leaf lattice
1 parent 807a8c6 commit 94d5aa7

File tree

4 files changed

+1001
-74
lines changed

4 files changed

+1001
-74
lines changed

varipeps/expectation/spiral_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def apply_unitary(
5656
if isinstance(delta_r, jnp.ndarray):
5757
delta_r = (delta_r,) * len(apply_to_index)
5858

59+
if isinstance(q, jnp.ndarray):
60+
q = (q,) * len(apply_to_index)
61+
5962
if len(q) != len(apply_to_index) or len(q) != len(delta_r):
6063
raise ValueError("Length mismatch!")
6164

varipeps/expectation/triangular_one_site.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def calc_triangular_one_site(
3939
peps_tensor_objs,
4040
gates,
4141
):
42+
if isinstance(peps_tensors, jnp.ndarray):
43+
peps_tensors = (peps_tensors,)
44+
peps_tensor_objs = (peps_tensor_objs,)
45+
4246
real_result = all(jnp.allclose(g, g.T.conj()) for g in gates)
4347

4448
return _one_site_workhorse(

varipeps/expectation/triangular_two_sites.py

Lines changed: 33 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,16 @@
1616
from typing import Sequence, List, Tuple, Union, Optional
1717

1818

19-
@partial(jit, static_argnums=(3,))
20-
def calc_triangular_two_sites_horizontal(
21-
peps_tensors,
22-
peps_tensor_objs,
19+
def calc_triangular_two_sites_workhorse(
20+
density_matrix_left,
21+
density_matrix_right,
2322
gates,
2423
real_result=False,
2524
):
26-
density_matrix_left = apply_contraction_jitted(
27-
"triangular_ctmrg_two_site_expectation_horizontal_left",
28-
[peps_tensors[0]],
29-
[peps_tensor_objs[0]],
30-
[],
31-
)
32-
3325
density_matrix_left = density_matrix_left.reshape(
3426
density_matrix_left.shape[0], density_matrix_left.shape[1], -1
3527
)
3628

37-
density_matrix_right = apply_contraction_jitted(
38-
"triangular_ctmrg_two_site_expectation_horizontal_right",
39-
[peps_tensors[1]],
40-
[peps_tensor_objs[1]],
41-
[],
42-
)
43-
4429
density_matrix_right = density_matrix_right.reshape(
4530
-1, density_matrix_right.shape[-2], density_matrix_right.shape[-1]
4631
)
@@ -69,55 +54,55 @@ def calc_triangular_two_sites_horizontal(
6954

7055

7156
@partial(jit, static_argnums=(3,))
72-
def calc_triangular_two_sites_vertical(
57+
def calc_triangular_two_sites_horizontal(
7358
peps_tensors,
7459
peps_tensor_objs,
7560
gates,
7661
real_result=False,
7762
):
78-
density_matrix_top = apply_contraction_jitted(
79-
"triangular_ctmrg_two_site_expectation_vertical_top",
63+
density_matrix_left = apply_contraction_jitted(
64+
"triangular_ctmrg_two_site_expectation_horizontal_left",
8065
[peps_tensors[0]],
8166
[peps_tensor_objs[0]],
8267
[],
8368
)
8469

85-
density_matrix_top = density_matrix_top.reshape(
86-
density_matrix_top.shape[0], density_matrix_top.shape[1], -1
87-
)
88-
89-
density_matrix_bottom = apply_contraction_jitted(
90-
"triangular_ctmrg_two_site_expectation_vertical_bottom",
70+
density_matrix_right = apply_contraction_jitted(
71+
"triangular_ctmrg_two_site_expectation_horizontal_right",
9172
[peps_tensors[1]],
9273
[peps_tensor_objs[1]],
9374
[],
9475
)
9576

96-
density_matrix_bottom = density_matrix_bottom.reshape(
97-
-1, density_matrix_bottom.shape[-2], density_matrix_bottom.shape[-1]
77+
return calc_triangular_two_sites_workhorse(
78+
density_matrix_left, density_matrix_right, gates, real_result
9879
)
9980

100-
density_matrix = jnp.tensordot(
101-
density_matrix_top, density_matrix_bottom, ((2,), (0,))
102-
)
10381

104-
density_matrix = density_matrix.transpose(0, 2, 1, 3)
105-
density_matrix = density_matrix.reshape(
106-
density_matrix.shape[0] * density_matrix.shape[1],
107-
density_matrix.shape[2] * density_matrix.shape[3],
82+
@partial(jit, static_argnums=(3,))
83+
def calc_triangular_two_sites_vertical(
84+
peps_tensors,
85+
peps_tensor_objs,
86+
gates,
87+
real_result=False,
88+
):
89+
density_matrix_top = apply_contraction_jitted(
90+
"triangular_ctmrg_two_site_expectation_vertical_top",
91+
[peps_tensors[0]],
92+
[peps_tensor_objs[0]],
93+
[],
10894
)
10995

110-
norm = jnp.trace(density_matrix)
96+
density_matrix_bottom = apply_contraction_jitted(
97+
"triangular_ctmrg_two_site_expectation_vertical_bottom",
98+
[peps_tensors[1]],
99+
[peps_tensor_objs[1]],
100+
[],
101+
)
111102

112-
if real_result:
113-
return [
114-
jnp.real(jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm)
115-
for g in gates
116-
]
117-
else:
118-
return [
119-
jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm for g in gates
120-
]
103+
return calc_triangular_two_sites_workhorse(
104+
density_matrix_top, density_matrix_bottom, gates, real_result
105+
)
121106

122107

123108
@partial(jit, static_argnums=(3,))
@@ -134,43 +119,17 @@ def calc_triangular_two_sites_diagonal(
134119
[],
135120
)
136121

137-
density_matrix_top = density_matrix_top.reshape(
138-
density_matrix_top.shape[0], density_matrix_top.shape[1], -1
139-
)
140-
141122
density_matrix_bottom = apply_contraction_jitted(
142123
"triangular_ctmrg_two_site_expectation_diagonal_bottom",
143124
[peps_tensors[1]],
144125
[peps_tensor_objs[1]],
145126
[],
146127
)
147128

148-
density_matrix_bottom = density_matrix_bottom.reshape(
149-
-1, density_matrix_bottom.shape[-2], density_matrix_bottom.shape[-1]
150-
)
151-
152-
density_matrix = jnp.tensordot(
153-
density_matrix_top, density_matrix_bottom, ((2,), (0,))
129+
return calc_triangular_two_sites_workhorse(
130+
density_matrix_top, density_matrix_bottom, gates, real_result
154131
)
155132

156-
density_matrix = density_matrix.transpose(0, 2, 1, 3)
157-
density_matrix = density_matrix.reshape(
158-
density_matrix.shape[0] * density_matrix.shape[1],
159-
density_matrix.shape[2] * density_matrix.shape[3],
160-
)
161-
162-
norm = jnp.trace(density_matrix)
163-
164-
if real_result:
165-
return [
166-
jnp.real(jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm)
167-
for g in gates
168-
]
169-
else:
170-
return [
171-
jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm for g in gates
172-
]
173-
174133

175134
@dataclass
176135
class Triangular_Two_Sites_Expectation_Value(Expectation_Model):

0 commit comments

Comments
 (0)