88from jax import jit
99import jax .util
1010
11+ from varipeps import varipeps_config
1112import varipeps .config
1213from varipeps .peps import PEPS_Tensor , PEPS_Unit_Cell
1314from varipeps .contractions import apply_contraction , Definitions
1415from varipeps .expectation .model import Expectation_Model
1516from varipeps .expectation .one_site import calc_one_site_multi_gates
1617from varipeps .expectation .two_sites import _two_site_workhorse
18+ from varipeps .expectation .spiral_helpers import apply_unitary
1719from varipeps .typing import Tensor
1820from varipeps .mapping import Map_To_PEPS_Model
1921from varipeps .utils .random import PEPS_Random_Number_Generator
@@ -111,8 +113,18 @@ def __post_init__(self) -> None:
111113 self ._y_tuple = tuple (self .y_gates )
112114 self ._z_tuple = tuple (self .z_gates )
113115
116+ self ._result_type = (
117+ jnp .float64
118+ if all (jnp .allclose (g , g .T .conj ()) for g in self .x_gates )
119+ and all (jnp .allclose (g , g .T .conj ()) for g in self .y_gates )
120+ and all (jnp .allclose (g , g .T .conj ()) for g in self .z_gates )
121+ else jnp .complex128
122+ )
123+
114124 if self .is_spiral_peps :
115- raise NotImplementedError
125+ self ._spiral_D , self ._spiral_sigma = jnp .linalg .eigh (
126+ self .spiral_unitary_operator
127+ )
116128
117129 def __call__ (
118130 self ,
@@ -124,15 +136,8 @@ def __call__(
124136 only_unique : bool = True ,
125137 return_single_gate_results : bool = False ,
126138 ) -> Union [jnp .ndarray , List [jnp .ndarray ]]:
127- result_type = (
128- jnp .float64
129- if all (jnp .allclose (g , jnp .real (g )) for g in self .x_gates )
130- and all (jnp .allclose (g , jnp .real (g )) for g in self .y_gates )
131- and all (jnp .allclose (g , jnp .real (g )) for g in self .z_gates )
132- else jnp .complex128
133- )
134139 result = [
135- jnp .array (0 , dtype = result_type )
140+ jnp .array (0 , dtype = self . _result_type )
136141 for _ in range (
137142 max (
138143 len (self .x_gates ),
@@ -142,6 +147,44 @@ def __call__(
142147 )
143148 ]
144149
150+ if self .is_spiral_peps :
151+ if isinstance (spiral_vectors , jnp .ndarray ):
152+ spiral_vectors = (spiral_vectors ,)
153+ if len (spiral_vectors ) != 1 :
154+ raise ValueError ("Length mismatch for spiral vectors!" )
155+
156+ working_h_gates = tuple (
157+ apply_unitary (
158+ h ,
159+ jnp .array ((0 , 1 )),
160+ spiral_vectors ,
161+ self ._spiral_D ,
162+ self ._spiral_sigma ,
163+ self .real_d ,
164+ 2 ,
165+ (1 ,),
166+ varipeps_config .spiral_wavevector_type ,
167+ )
168+ for h in self ._y_tuple
169+ )
170+ working_v_gates = tuple (
171+ apply_unitary (
172+ v ,
173+ jnp .array ((1 , 0 )),
174+ spiral_vectors ,
175+ self ._spiral_D ,
176+ self ._spiral_sigma ,
177+ self .real_d ,
178+ 2 ,
179+ (1 ,),
180+ varipeps_config .spiral_wavevector_type ,
181+ )
182+ for v in self ._z_tuple
183+ )
184+ else :
185+ working_h_gates = self ._y_tuple
186+ working_v_gates = self ._z_tuple
187+
145188 for x , iter_rows in unitcell .iter_all_rows (only_unique = only_unique ):
146189 for y , view in iter_rows :
147190 # On site x term
@@ -196,8 +239,8 @@ def __call__(
196239 step_result_y = _two_site_workhorse (
197240 density_matrix_left ,
198241 density_matrix_right ,
199- self . _y_tuple ,
200- result_type is jnp .float64 ,
242+ working_h_gates ,
243+ self . _result_type is jnp .float64 ,
201244 )
202245
203246 for sr_i , sr in enumerate (step_result_y ):
@@ -241,8 +284,8 @@ def __call__(
241284 step_result_z = _two_site_workhorse (
242285 density_matrix_top ,
243286 density_matrix_bottom ,
244- self . _z_tuple ,
245- result_type is jnp .float64 ,
287+ working_v_gates ,
288+ self . _result_type is jnp .float64 ,
246289 )
247290
248291 for sr_i , sr in enumerate (step_result_z ):
0 commit comments