@@ -234,6 +234,24 @@ def calc_ctmrg_expectation_custom(
234234 peps_tensors , unitcell , spiral_vectors = _map_tensors (
235235 input_tensors , unitcell , convert_to_unitcell_func , True
236236 )
237+
238+ if any (i .size == 1 for i in spiral_vectors ):
239+ spiral_vectors_x = additional_input .get ("spiral_vectors_x" )
240+ spiral_vectors_y = additional_input .get ("spiral_vectors_y" )
241+ if spiral_vectors_x is not None :
242+ if isinstance (spiral_vectors_x , jnp .ndarray ):
243+ spiral_vectors_x = (spiral_vectors_x ,)
244+ spiral_vectors = tuple (
245+ jnp .array ((sx , sy ))
246+ for sx , sy in safe_zip (spiral_vectors_x , spiral_vectors )
247+ )
248+ elif spiral_vectors_y is not None :
249+ if isinstance (spiral_vectors_y , jnp .ndarray ):
250+ spiral_vectors_y = (spiral_vectors_y ,)
251+ spiral_vectors = tuple (
252+ jnp .array ((sx , sy ))
253+ for sx , sy in safe_zip (spiral_vectors , spiral_vectors_y )
254+ )
237255 else :
238256 peps_tensors , unitcell = _map_tensors (
239257 input_tensors , unitcell , convert_to_unitcell_func , False
0 commit comments