@@ -99,6 +99,24 @@ def calc_ctmrg_expectation(
9999 peps_tensors , unitcell , spiral_vectors = _map_tensors (
100100 input_tensors , unitcell , convert_to_unitcell_func , True
101101 )
102+
103+ if any (i .size == 1 for i in spiral_vectors ):
104+ spiral_vectors_x = additional_input .get ("spiral_vectors_x" )
105+ spiral_vectors_y = additional_input .get ("spiral_vectors_y" )
106+ if spiral_vectors_x is not None :
107+ if isinstance (spiral_vectors_x , jnp .ndarray ):
108+ spiral_vectors_x = (spiral_vectors_x ,)
109+ spiral_vectors = tuple (
110+ jnp .array ((sx , sy ))
111+ for sx , sy in safe_zip (spiral_vectors_x , spiral_vectors )
112+ )
113+ elif spiral_vectors_y is not None :
114+ if isinstance (spiral_vectors_y , jnp .ndarray ):
115+ spiral_vectors_y = (spiral_vectors_y ,)
116+ spiral_vectors = tuple (
117+ jnp .array ((sx , sy ))
118+ for sx , sy in safe_zip (spiral_vectors , spiral_vectors_y )
119+ )
102120 else :
103121 peps_tensors , unitcell = _map_tensors (
104122 input_tensors , unitcell , convert_to_unitcell_func , False
@@ -175,6 +193,24 @@ def calc_preconverged_ctmrg_value_and_grad(
175193 peps_tensors , unitcell , spiral_vectors = _map_tensors (
176194 input_tensors , unitcell , convert_to_unitcell_func , True
177195 )
196+
197+ if any (i .size == 1 for i in spiral_vectors ):
198+ spiral_vectors_x = additional_input .get ("spiral_vectors_x" )
199+ spiral_vectors_y = additional_input .get ("spiral_vectors_y" )
200+ if spiral_vectors_x is not None :
201+ if isinstance (spiral_vectors_x , jnp .ndarray ):
202+ spiral_vectors_x = (spiral_vectors_x ,)
203+ spiral_vectors = tuple (
204+ jnp .array ((sx , sy ))
205+ for sx , sy in safe_zip (spiral_vectors_x , spiral_vectors )
206+ )
207+ elif spiral_vectors_y is not None :
208+ if isinstance (spiral_vectors_y , jnp .ndarray ):
209+ spiral_vectors_y = (spiral_vectors_y ,)
210+ spiral_vectors = tuple (
211+ jnp .array ((sx , sy ))
212+ for sx , sy in safe_zip (spiral_vectors , spiral_vectors_y )
213+ )
178214 else :
179215 peps_tensors , unitcell = _map_tensors (
180216 input_tensors , unitcell , convert_to_unitcell_func , False
0 commit comments