1111import jax .numpy as jnp
1212from jax .lax import scan
1313from jax .flatten_util import ravel_pytree
14+ from jax .util import safe_zip
1415
1516from varipeps import varipeps_config , varipeps_global_state
1617from varipeps .config import Optimizing_Methods
@@ -213,7 +214,7 @@ def _autosave_wrapper(
213214 additional_input ,
214215):
215216 auxiliary_data = {
216- "best_run" : best_run if best_run is not None else 0 ,
217+ "best_run" : jnp . array ( best_run if best_run is not None else 0 ) ,
217218 "current_energy" : working_value ,
218219 }
219220
@@ -223,16 +224,43 @@ def _autosave_wrapper(
223224 auxiliary_data [f"step_chi_{ k :d} " ] = step_chi [k ]
224225 auxiliary_data [f"step_conv_{ k :d} " ] = step_conv [k ]
225226
227+ spiral_vectors = None
226228 if spiral_indices is not None :
227- for spiral_i in spiral_indices :
228- auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = working_tensors [spiral_i ]
229+ spiral_vectors = [working_tensors [spiral_i ] for spiral_i in spiral_indices ]
230+
231+ if any (i .size == 1 for i in spiral_vectors ):
232+ spiral_vectors_x = additional_input .get ("spiral_vectors_x" )
233+ spiral_vectors_y = additional_input .get ("spiral_vectors_y" )
234+ if spiral_vectors_x is not None :
235+ if isinstance (spiral_vectors_x , jnp .ndarray ):
236+ spiral_vectors_x = (spiral_vectors_x ,)
237+ spiral_vectors = tuple (
238+ jnp .array ((sx , sy ))
239+ for sx , sy in safe_zip (spiral_vectors_x , spiral_vectors )
240+ )
241+ elif spiral_vectors_y is not None :
242+ if isinstance (spiral_vectors_y , jnp .ndarray ):
243+ spiral_vectors_y = (spiral_vectors_y ,)
244+ spiral_vectors = tuple (
245+ jnp .array ((sx , sy ))
246+ for sx , sy in safe_zip (spiral_vectors , spiral_vectors_y )
247+ )
229248 elif additional_input .get ("spiral_vectors" ) is not None :
230- add_input_spiral = additional_input .get ("spiral_vectors" )
231- if isinstance (add_input_spiral , jnp .ndarray ):
232- add_input_spiral = (add_input_spiral ,)
233- for spiral_i , elem in enumerate (add_input_spiral ):
234- spiral_i += 1
235- auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = elem
249+ spiral_vectors = additional_input .get ("spiral_vectors" )
250+ if isinstance (spiral_vectors , jnp .ndarray ):
251+ spiral_vectors = (spiral_vectors ,)
252+
253+ if spiral_vectors is not None :
254+ spiral_vectors = [
255+ e if e .size == 2 else jnp .array ((e , e )).reshape (2 ) for e in spiral_vectors
256+ ]
257+
258+ if len (spiral_vectors ) == 1 :
259+ auxiliary_data ["spiral_vector" ] = spiral_vectors [0 ]
260+ else :
261+ for spiral_i , vec in enumerate (spiral_vectors ):
262+ spiral_i += 1
263+ auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = vec
236264
237265 autosave_func (
238266 autosave_filename ,
0 commit comments