1616from typing import Sequence , List , Tuple , Union , Optional
1717
1818
19- @partial (jit , static_argnums = (3 ,))
20- def calc_triangular_two_sites_horizontal (
21- peps_tensors ,
22- peps_tensor_objs ,
19+ def calc_triangular_two_sites_workhorse (
20+ density_matrix_left ,
21+ density_matrix_right ,
2322 gates ,
2423 real_result = False ,
2524):
26- density_matrix_left = apply_contraction_jitted (
27- "triangular_ctmrg_two_site_expectation_horizontal_left" ,
28- [peps_tensors [0 ]],
29- [peps_tensor_objs [0 ]],
30- [],
31- )
32-
3325 density_matrix_left = density_matrix_left .reshape (
3426 density_matrix_left .shape [0 ], density_matrix_left .shape [1 ], - 1
3527 )
3628
37- density_matrix_right = apply_contraction_jitted (
38- "triangular_ctmrg_two_site_expectation_horizontal_right" ,
39- [peps_tensors [1 ]],
40- [peps_tensor_objs [1 ]],
41- [],
42- )
43-
4429 density_matrix_right = density_matrix_right .reshape (
4530 - 1 , density_matrix_right .shape [- 2 ], density_matrix_right .shape [- 1 ]
4631 )
@@ -69,55 +54,55 @@ def calc_triangular_two_sites_horizontal(
6954
7055
7156@partial (jit , static_argnums = (3 ,))
72- def calc_triangular_two_sites_vertical (
57+ def calc_triangular_two_sites_horizontal (
7358 peps_tensors ,
7459 peps_tensor_objs ,
7560 gates ,
7661 real_result = False ,
7762):
78- density_matrix_top = apply_contraction_jitted (
79- "triangular_ctmrg_two_site_expectation_vertical_top " ,
63+ density_matrix_left = apply_contraction_jitted (
64+ "triangular_ctmrg_two_site_expectation_horizontal_left " ,
8065 [peps_tensors [0 ]],
8166 [peps_tensor_objs [0 ]],
8267 [],
8368 )
8469
85- density_matrix_top = density_matrix_top .reshape (
86- density_matrix_top .shape [0 ], density_matrix_top .shape [1 ], - 1
87- )
88-
89- density_matrix_bottom = apply_contraction_jitted (
90- "triangular_ctmrg_two_site_expectation_vertical_bottom" ,
70+ density_matrix_right = apply_contraction_jitted (
71+ "triangular_ctmrg_two_site_expectation_horizontal_right" ,
9172 [peps_tensors [1 ]],
9273 [peps_tensor_objs [1 ]],
9374 [],
9475 )
9576
96- density_matrix_bottom = density_matrix_bottom . reshape (
97- - 1 , density_matrix_bottom . shape [ - 2 ], density_matrix_bottom . shape [ - 1 ]
77+ return calc_triangular_two_sites_workhorse (
78+ density_matrix_left , density_matrix_right , gates , real_result
9879 )
9980
100- density_matrix = jnp .tensordot (
101- density_matrix_top , density_matrix_bottom , ((2 ,), (0 ,))
102- )
10381
104- density_matrix = density_matrix .transpose (0 , 2 , 1 , 3 )
105- density_matrix = density_matrix .reshape (
106- density_matrix .shape [0 ] * density_matrix .shape [1 ],
107- density_matrix .shape [2 ] * density_matrix .shape [3 ],
82+ @partial (jit , static_argnums = (3 ,))
83+ def calc_triangular_two_sites_vertical (
84+ peps_tensors ,
85+ peps_tensor_objs ,
86+ gates ,
87+ real_result = False ,
88+ ):
89+ density_matrix_top = apply_contraction_jitted (
90+ "triangular_ctmrg_two_site_expectation_vertical_top" ,
91+ [peps_tensors [0 ]],
92+ [peps_tensor_objs [0 ]],
93+ [],
10894 )
10995
110- norm = jnp .trace (density_matrix )
96+ density_matrix_bottom = apply_contraction_jitted (
97+ "triangular_ctmrg_two_site_expectation_vertical_bottom" ,
98+ [peps_tensors [1 ]],
99+ [peps_tensor_objs [1 ]],
100+ [],
101+ )
111102
112- if real_result :
113- return [
114- jnp .real (jnp .tensordot (density_matrix , g , ((0 , 1 ), (0 , 1 ))) / norm )
115- for g in gates
116- ]
117- else :
118- return [
119- jnp .tensordot (density_matrix , g , ((0 , 1 ), (0 , 1 ))) / norm for g in gates
120- ]
103+ return calc_triangular_two_sites_workhorse (
104+ density_matrix_top , density_matrix_bottom , gates , real_result
105+ )
121106
122107
123108@partial (jit , static_argnums = (3 ,))
@@ -134,43 +119,17 @@ def calc_triangular_two_sites_diagonal(
134119 [],
135120 )
136121
137- density_matrix_top = density_matrix_top .reshape (
138- density_matrix_top .shape [0 ], density_matrix_top .shape [1 ], - 1
139- )
140-
141122 density_matrix_bottom = apply_contraction_jitted (
142123 "triangular_ctmrg_two_site_expectation_diagonal_bottom" ,
143124 [peps_tensors [1 ]],
144125 [peps_tensor_objs [1 ]],
145126 [],
146127 )
147128
148- density_matrix_bottom = density_matrix_bottom .reshape (
149- - 1 , density_matrix_bottom .shape [- 2 ], density_matrix_bottom .shape [- 1 ]
150- )
151-
152- density_matrix = jnp .tensordot (
153- density_matrix_top , density_matrix_bottom , ((2 ,), (0 ,))
129+ return calc_triangular_two_sites_workhorse (
130+ density_matrix_top , density_matrix_bottom , gates , real_result
154131 )
155132
156- density_matrix = density_matrix .transpose (0 , 2 , 1 , 3 )
157- density_matrix = density_matrix .reshape (
158- density_matrix .shape [0 ] * density_matrix .shape [1 ],
159- density_matrix .shape [2 ] * density_matrix .shape [3 ],
160- )
161-
162- norm = jnp .trace (density_matrix )
163-
164- if real_result :
165- return [
166- jnp .real (jnp .tensordot (density_matrix , g , ((0 , 1 ), (0 , 1 ))) / norm )
167- for g in gates
168- ]
169- else :
170- return [
171- jnp .tensordot (density_matrix , g , ((0 , 1 ), (0 , 1 ))) / norm for g in gates
172- ]
173-
174133
175134@dataclass
176135class Triangular_Two_Sites_Expectation_Value (Expectation_Model ):
0 commit comments