@@ -277,6 +277,7 @@ def random_noise(a):
277277 best_value = jnp .inf
278278 best_tensors = None
279279 best_unitcell = None
280+ best_run = None
280281
281282 random_noise_retries = 0
282283
@@ -307,7 +308,10 @@ def random_noise(a):
307308 varipeps_config .line_search_initial_step_size
308309 )
309310 working_value : Union [float , jnp .ndarray ]
310- max_trunc_error_list = []
311+ max_trunc_error_list = {random_noise_retries : []}
312+ step_energies = {random_noise_retries : []}
313+ step_chi = {random_noise_retries : []}
314+ step_conv = {random_noise_retries : []}
311315
312316 if (
313317 varipeps_config .optimizer_preconverge_with_half_projectors
@@ -417,6 +421,7 @@ def random_noise(a):
417421 descent_dir = [- elem for elem in working_gradient ]
418422
419423 conv = jnp .linalg .norm (ravel_pytree (working_gradient )[0 ])
424+ step_conv [random_noise_retries ].append (conv )
420425
421426 try :
422427 (
@@ -460,6 +465,7 @@ def random_noise(a):
460465 best_value = working_value
461466 best_tensors = working_tensors
462467 best_unitcell = working_unitcell
468+ best_run = random_noise_retries
463469
464470 if isinstance (input_tensors , PEPS_Unit_Cell ) or (
465471 isinstance (input_tensors , collections .abc .Sequence )
@@ -497,12 +503,22 @@ def random_noise(a):
497503 signal_reset_descent_dir = True
498504 count = - 1
499505 random_noise_retries += 1
506+
507+ step_energies [random_noise_retries ] = []
508+ step_chi [random_noise_retries ] = []
509+ step_conv [random_noise_retries ] = []
510+ max_trunc_error_list [random_noise_retries ] = []
511+
500512 pbar .reset ()
501513 pbar .refresh ()
502514 else :
503515 conv = 0
504-
505- max_trunc_error_list .append (max_trunc_error )
516+ else :
517+ max_trunc_error_list [random_noise_retries ].append (max_trunc_error )
518+ step_energies [random_noise_retries ].append (working_value )
519+ step_chi [random_noise_retries ].append (
520+ working_unitcell .get_unique_tensors ()[0 ].chi
521+ )
506522
507523 if conv < varipeps_config .optimizer_convergence_eps :
508524 working_value , (
@@ -517,7 +533,8 @@ def random_noise(a):
517533 enforce_elementwise_convergence = varipeps_config .ad_use_custom_vjp ,
518534 )
519535 varipeps_global_state .ctmrg_projector_method = None
520- max_trunc_error_list [- 1 ] = max_trunc_error
536+ max_trunc_error_list [random_noise_retries ][- 1 ] = max_trunc_error
537+ step_energies [random_noise_retries ][- 1 ] = working_value
521538 break
522539
523540 if (
@@ -561,7 +578,16 @@ def random_noise(a):
561578
562579 if count % varipeps_config .optimizer_autosave_step_count == 0 :
563580 auxiliary_data = {
564- "max_trunc_error_list" : max_trunc_error_list ,
581+ "max_trunc_error_list" : tuple (
582+ max_trunc_error_list [k ]
583+ for k in sorted (max_trunc_error_list .keys ())
584+ ),
585+ "step_energies" : tuple (
586+ step_energies [k ] for k in sorted (step_energies .keys ())
587+ ),
588+ "step_chi" : tuple (step_chi [k ] for k in sorted (step_chi .keys ())),
589+ "step_conv" : tuple (step_conv [k ] for k in sorted (step_conv .keys ())),
590+ "best_run" : best_run if best_run is not None else 0 ,
565591 }
566592
567593 if spiral_indices is not None :
@@ -589,6 +615,7 @@ def random_noise(a):
589615 best_value = working_value
590616 best_tensors = working_tensors
591617 best_unitcell = working_unitcell
618+ best_run = random_noise_retries
592619
593620 print (f"Best energy result found: { best_value } " )
594621
@@ -599,4 +626,8 @@ def random_noise(a):
599626 unitcell = best_unitcell ,
600627 nit = count ,
601628 max_trunc_error_list = max_trunc_error_list ,
629+ step_energies = step_energies ,
630+ step_chi = step_chi ,
631+ step_conv = step_conv ,
632+ best_run = best_run ,
602633 )
0 commit comments