|
1 | 1 | import collections |
2 | 2 | from collections import deque |
| 3 | +import datetime |
3 | 4 | from functools import partial |
4 | 5 | from os import PathLike |
5 | 6 | import time |
|
8 | 9 |
|
9 | 10 | from tqdm_loggable.auto import tqdm |
10 | 11 |
|
| 12 | +import numpy as np |
| 13 | + |
11 | 14 | import jax |
12 | 15 | from jax import jit |
13 | 16 | import jax.numpy as jnp |
|
22 | 25 | from varipeps.mapping import Map_To_PEPS_Model |
23 | 26 | from varipeps.ctmrg import CTMRGNotConvergedError, CTMRGGradientNotConvergedError |
24 | 27 | from varipeps.utils.random import PEPS_Random_Number_Generator |
| 28 | +from varipeps.utils.slurm import SlurmUtils |
25 | 29 |
|
26 | 30 | from .inner_function import ( |
27 | 31 | calc_ctmrg_expectation, |
@@ -283,6 +287,7 @@ def optimize_peps_network( |
283 | 287 | [PathLike, Sequence[jnp.ndarray], PEPS_Unit_Cell], None |
284 | 288 | ] = autosave_function, |
285 | 289 | additional_input: Dict[str, jnp.ndarray] = {}, |
| 290 | + slurm_restart_script: Optional[PathLike] = None, |
286 | 291 | ) -> Tuple[Sequence[jnp.ndarray], PEPS_Unit_Cell, Union[float, jnp.ndarray]]: |
287 | 292 | """ |
288 | 293 | Optimize a PEPS unitcell using a variational method. |
@@ -830,6 +835,32 @@ def random_noise(a): |
830 | 835 | best_unitcell = working_unitcell |
831 | 836 | best_run = random_noise_retries |
832 | 837 |
|
| 838 | + if ( |
| 839 | + slurm_restart_script is not None |
| 840 | + and (slurm_data := SlurmUtils.get_own_job_data()) is not None |
| 841 | + ): |
| 842 | + flatten_runtime = [j for i in step_runtime for j in i] |
| 843 | + runtime_mean = np.mean(flatten_runtime) |
| 844 | + runtime_std = np.std(flatten_runtime) |
| 845 | + |
| 846 | + if runtime_std > 0: |
| 847 | + remaining_slurm_time = ( |
| 848 | + slurm_data["TimeLimit"] - slurm_data["RunTime"] |
| 849 | + ) |
| 850 | + time_of_one_step = datetime.timedelta( |
| 851 | + seconds=runtime_mean + 3 * runtime_std |
| 852 | + ) |
| 853 | + |
| 854 | + if remaining_slurm_time < time_of_one_step: |
| 855 | + new_job_id = SlurmUtils.run_slurm_script(slurm_restart_script) |
| 856 | + if new_job_id is not None: |
| 857 | + tqdm.write(f"Started new Slurm job with ID {new_job_id:d}.") |
| 858 | + else: |
| 859 | + tqdm.write( |
| 860 | + "Failed to start new Slurm job or parse its job id." |
| 861 | + ) |
| 862 | + break |
| 863 | + |
833 | 864 | if working_value < best_value: |
834 | 865 | best_value = working_value |
835 | 866 | best_tensors = working_tensors |
|
0 commit comments