@@ -400,15 +400,70 @@ def random_noise(a):
400400 except (CTMRGNotConvergedError , CTMRGGradientNotConvergedError ) as e :
401401 varipeps_global_state .ctmrg_projector_method = None
402402
403- return OptimizeResult (
404- success = False ,
405- message = str (type (e )),
406- x = working_tensors ,
407- fun = working_value ,
408- unitcell = working_unitcell ,
409- nit = count ,
410- max_trunc_error_list = max_trunc_error_list ,
411- )
403+ if random_noise_retries == 0 :
404+ return OptimizeResult (
405+ success = False ,
406+ message = str (type (e )),
407+ x = working_tensors ,
408+ fun = working_value ,
409+ unitcell = working_unitcell ,
410+ nit = count ,
411+ max_trunc_error_list = max_trunc_error_list ,
412+ step_energies = step_energies ,
413+ step_chi = step_chi ,
414+ step_conv = step_conv ,
415+ best_run = 0 ,
416+ )
417+ elif (
418+ random_noise_retries
419+ >= varipeps_config .optimizer_random_noise_max_retries
420+ ):
421+ working_value = jnp .inf
422+ break
423+ else :
424+ if isinstance (input_tensors , PEPS_Unit_Cell ) or (
425+ isinstance (input_tensors , collections .abc .Sequence )
426+ and isinstance (input_tensors [0 ], PEPS_Unit_Cell )
427+ ):
428+ working_tensors = (
429+ cast (
430+ List [jnp .ndarray ],
431+ [i .tensor for i in best_unitcell .get_unique_tensors ()],
432+ )
433+ + best_tensors [best_unitcell .get_len_unique_tensors () :]
434+ )
435+
436+ working_tensors = [random_noise (i ) for i in working_tensors ]
437+
438+ working_tensors_obj = [
439+ e .replace_tensor (working_tensors [i ])
440+ for i , e in enumerate (best_unitcell .get_unique_tensors ())
441+ ]
442+
443+ working_unitcell = best_unitcell .replace_unique_tensors (
444+ working_tensors_obj
445+ )
446+ else :
447+ working_tensors = [random_noise (i ) for i in best_tensors ]
448+ working_unitcell = None
449+
450+ descent_dir = None
451+ working_gradient = None
452+ signal_reset_descent_dir = True
453+ count = 0
454+ random_noise_retries += 1
455+ old_descent_dir = descent_dir
456+ old_gradient = working_gradient
457+
458+ step_energies [random_noise_retries ] = []
459+ step_chi [random_noise_retries ] = []
460+ step_conv [random_noise_retries ] = []
461+ max_trunc_error_list [random_noise_retries ] = []
462+
463+ pbar .reset ()
464+ pbar .refresh ()
465+
466+ continue
412467
413468 working_gradient = [elem .conj () for elem in working_gradient_seq ]
414469
@@ -567,9 +622,10 @@ def random_noise(a):
567622 descent_dir = None
568623 working_gradient = None
569624 signal_reset_descent_dir = True
570- count = - 1
625+ count = 0
571626 random_noise_retries += 1
572- conv = jnp .inf
627+ old_descent_dir = descent_dir
628+ old_gradient = working_gradient
573629
574630 step_energies [random_noise_retries ] = []
575631 step_chi [random_noise_retries ] = []
@@ -578,6 +634,8 @@ def random_noise(a):
578634
579635 pbar .reset ()
580636 pbar .refresh ()
637+
638+ continue
581639 else :
582640 conv = 0
583641 else :
0 commit comments