@@ -186,7 +186,7 @@ def autosave_function(
186186 filename : PathLike ,
187187 tensors : jnp .ndarray ,
188188 unitcell : PEPS_Unit_Cell ,
189- counter : Optional [int ] = None ,
189+ counter : Optional [Union [ int , str ] ] = None ,
190190 auxiliary_data : Optional [Dict [str , Any ]] = None ,
191191) -> None :
192192 if counter is not None :
@@ -197,6 +197,52 @@ def autosave_function(
197197 unitcell .save_to_file (filename , auxiliary_data = auxiliary_data )
198198
199199
200+ def _autosave_wrapper (
201+ autosave_func ,
202+ autosave_filename ,
203+ working_tensors ,
204+ working_unitcell ,
205+ working_value ,
206+ counter ,
207+ best_run ,
208+ max_trunc_error_list ,
209+ step_energies ,
210+ step_chi ,
211+ step_conv ,
212+ spiral_indices ,
213+ additional_input ,
214+ ):
215+ auxiliary_data = {
216+ "best_run" : best_run if best_run is not None else 0 ,
217+ "current_energy" : working_value ,
218+ }
219+
220+ for k in sorted (max_trunc_error_list .keys ()):
221+ auxiliary_data [f"max_trunc_error_list_{ k :d} " ] = max_trunc_error_list [k ]
222+ auxiliary_data [f"step_energies_{ k :d} " ] = step_energies [k ]
223+ auxiliary_data [f"step_chi_{ k :d} " ] = step_chi [k ]
224+ auxiliary_data [f"step_conv_{ k :d} " ] = step_conv [k ]
225+
226+ if spiral_indices is not None :
227+ for spiral_i in spiral_indices :
228+ auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = working_tensors [spiral_i ]
229+ elif additional_input .get ("spiral_vectors" ) is not None :
230+ add_input_spiral = additional_input .get ("spiral_vectors" )
231+ if isinstance (add_input_spiral , jnp .ndarray ):
232+ add_input_spiral = (add_input_spiral ,)
233+ for spiral_i , elem in enumerate (add_input_spiral ):
234+ spiral_i += 1
235+ auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = elem
236+
237+ autosave_func (
238+ autosave_filename ,
239+ working_tensors ,
240+ working_unitcell ,
241+ counter = counter ,
242+ auxiliary_data = auxiliary_data ,
243+ )
244+
245+
200246def optimize_peps_network (
201247 input_tensors : Union [PEPS_Unit_Cell , Sequence [jnp .ndarray ]],
202248 expectation_func : Expectation_Model ,
@@ -467,6 +513,22 @@ def random_noise(a):
467513 best_unitcell = working_unitcell
468514 best_run = random_noise_retries
469515
516+ _autosave_wrapper (
517+ autosave_func ,
518+ autosave_filename ,
519+ working_tensors ,
520+ working_unitcell ,
521+ working_value ,
522+ "best" ,
523+ best_run ,
524+ max_trunc_error_list ,
525+ step_energies ,
526+ step_chi ,
527+ step_conv ,
528+ spiral_indices ,
529+ additional_input ,
530+ )
531+
470532 if isinstance (input_tensors , PEPS_Unit_Cell ) or (
471533 isinstance (input_tensors , collections .abc .Sequence )
472534 and isinstance (input_tensors [0 ], PEPS_Unit_Cell )
@@ -577,37 +639,20 @@ def random_noise(a):
577639 pbar .refresh ()
578640
579641 if count % varipeps_config .optimizer_autosave_step_count == 0 :
580- auxiliary_data = {
581- "best_run" : best_run if best_run is not None else 0 ,
582- }
583-
584- for k in sorted (max_trunc_error_list .keys ()):
585- auxiliary_data [f"max_trunc_error_list_{ k :d} " ] = (
586- max_trunc_error_list [k ]
587- )
588- auxiliary_data [f"step_energies_{ k :d} " ] = step_energies [k ]
589- auxiliary_data [f"step_chi_{ k :d} " ] = step_chi [k ]
590- auxiliary_data [f"step_conv_{ k :d} " ] = step_conv [k ]
591-
592- if spiral_indices is not None :
593- for spiral_i in spiral_indices :
594- auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = working_tensors [
595- spiral_i
596- ]
597- elif additional_input .get ("spiral_vectors" ) is not None :
598- add_input_spiral = additional_input .get ("spiral_vectors" )
599- if isinstance (add_input_spiral , jnp .ndarray ):
600- add_input_spiral = (add_input_spiral ,)
601- for spiral_i , elem in enumerate (add_input_spiral ):
602- spiral_i += 1
603- auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = elem
604-
605- autosave_func (
642+ _autosave_wrapper (
643+ autosave_func ,
606644 autosave_filename ,
607645 working_tensors ,
608646 working_unitcell ,
609- counter = random_noise_retries ,
610- auxiliary_data = auxiliary_data ,
647+ working_value ,
648+ random_noise_retries ,
649+ best_run ,
650+ max_trunc_error_list ,
651+ step_energies ,
652+ step_chi ,
653+ step_conv ,
654+ spiral_indices ,
655+ additional_input ,
611656 )
612657
613658 if working_value < best_value :
@@ -616,6 +661,22 @@ def random_noise(a):
616661 best_unitcell = working_unitcell
617662 best_run = random_noise_retries
618663
664+ _autosave_wrapper (
665+ autosave_func ,
666+ autosave_filename ,
667+ working_tensors ,
668+ working_unitcell ,
669+ working_value ,
670+ "best" ,
671+ best_run ,
672+ max_trunc_error_list ,
673+ step_energies ,
674+ step_chi ,
675+ step_conv ,
676+ spiral_indices ,
677+ additional_input ,
678+ )
679+
619680 print (f"Best energy result found: { best_value } " )
620681
621682 return OptimizeResult (
0 commit comments