@@ -82,6 +82,7 @@ class Triangular_Expectation_Value(Expectation_Model):
8282 """
8383
8484 nearest_neighbor_gates : Sequence [jnp .ndarray ]
85+ real_d : int
8586 normalization_factor : int = 1
8687
8788 is_spiral_peps : bool = False
@@ -90,6 +91,22 @@ class Triangular_Expectation_Value(Expectation_Model):
9091 def __post_init__ (self ) -> None :
9192 if isinstance (self .nearest_neighbor_gates , jnp .ndarray ):
9293 self .nearest_neighbor_gates = (self .nearest_neighbor_gates ,)
94+ else :
95+ self .nearest_neighbor_gates = tuple (self .nearest_neighbor_gates )
96+
97+ self ._result_type = (
98+ jnp .float64
99+ if all (
100+ jnp .allclose (g , g .T .conj ())
101+ for g in self .nearest_neighbor_gates
102+ )
103+ else jnp .complex128
104+ )
105+
106+ if self .is_spiral_peps :
107+ self ._spiral_D , self ._spiral_sigma = jnp .linalg .eigh (
108+ self .spiral_unitary_operator
109+ )
93110
94111 def __call__ (
95112 self ,
@@ -101,32 +118,84 @@ def __call__(
101118 only_unique : bool = True ,
102119 return_single_gate_results : bool = False ,
103120 ) -> Union [jnp .ndarray , List [jnp .ndarray ]]:
104- result_type = (
105- jnp .float64
106- if all (jnp .allclose (g , jnp .real (g )) for g in self .nearest_neighbor_gates )
107- else jnp .complex128
108- )
109121 result = [
110- jnp .array (0 , dtype = result_type )
122+ jnp .array (0 , dtype = self . _result_type )
111123 for _ in range (len (self .nearest_neighbor_gates ))
112124 ]
113125
126+ if self .is_spiral_peps :
127+ if (
128+ isinstance (spiral_vectors , collections .abc .Sequence )
129+ and len (spiral_vectors ) == 1
130+ ):
131+ spiral_vectors = spiral_vectors [0 ]
132+
133+ if not isinstance (spiral_vectors , jnp .ndarray ):
134+ raise ValueError ("Expect spiral vector as single jax.numpy array." )
135+
136+ working_h_gates = tuple (
137+ apply_unitary (
138+ h ,
139+ jnp .array ((0 , 1 )),
140+ (spiral_vectors ,),
141+ self ._spiral_D ,
142+ self ._spiral_sigma ,
143+ self .real_d ,
144+ 2 ,
145+ (1 ,),
146+ varipeps_config .spiral_wavevector_type ,
147+ )
148+ for h in self .nearest_neighbor_gates
149+ )
150+ working_v_gates = tuple (
151+ apply_unitary (
152+ v ,
153+ jnp .array ((1 , 0 )),
154+ (spiral_vectors ,),
155+ self ._spiral_D ,
156+ self ._spiral_sigma ,
157+ self .real_d ,
158+ 2 ,
159+ (1 ,),
160+ varipeps_config .spiral_wavevector_type ,
161+ )
162+ for v in self .nearest_neighbor_gates
163+ )
164+ working_d_gates = tuple (
165+ apply_unitary (
166+ d ,
167+ jnp .array ((1 , 1 )),
168+ (spiral_vectors ,),
169+ self ._spiral_D ,
170+ self ._spiral_sigma ,
171+ self .real_d ,
172+ 2 ,
173+ (1 ,),
174+ varipeps_config .spiral_wavevector_type ,
175+ )
176+ for d in self .nearest_neighbor_gates
177+ )
178+ else :
179+ working_h_gates = self .nearest_neighbor_gates
180+ working_v_gates = self .nearest_neighbor_gates
181+ working_d_gates = self .nearest_neighbor_gates
182+
114183 for x , iter_rows in unitcell .iter_all_rows (only_unique = only_unique ):
115184 for y , view in iter_rows :
116185 x_tensors_i = view .get_indices ((slice (0 , 2 , None ), 0 ))
117186 x_tensors = [peps_tensors [i ] for j in x_tensors_i for i in j ]
118187 x_tensor_objs = [t for tl in view [:2 , 0 ] for t in tl ]
119188
120189 step_result_x = calc_two_sites_vertical_multiple_gates (
121- x_tensors , x_tensor_objs , self . nearest_neighbor_gates
190+ x_tensors , x_tensor_objs , working_v_gates
122191 )
123192
124193 y_tensors_i = view .get_indices ((0 , slice (0 , 2 , None )))
125194 y_tensors = [peps_tensors [i ] for j in y_tensors_i for i in j ]
126195 y_tensor_objs = [t for tl in view [0 , :2 ] for t in tl ]
127196
128197 step_result_y = calc_two_sites_horizontal_multiple_gates (
129- y_tensors , y_tensor_objs , self . nearest_neighbor_gates
198+ y_tensors , y_tensor_objs , working_h_gates
130199 )
131200
132201 diagonal_tensors_i = view .get_indices (
@@ -141,7 +210,7 @@ def __call__(
141210 calc_two_sites_diagonal_top_left_bottom_right_multiple_gates (
142211 diagonal_tensors ,
143212 diagonal_tensor_objs ,
144- self . nearest_neighbor_gates ,
213+ working_d_gates ,
145214 )
146215 )
147216
0 commit comments