Skip to content

Commit 4b4c3b8

Browse files
committed
Add feature to run slurm restart script if the runtime seems to be not enough for the current job
1 parent 7a7bf12 commit 4b4c3b8

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

varipeps/optimization/optimizer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import collections
22
from collections import deque
3+
import datetime
34
from functools import partial
45
from os import PathLike
56
import time
@@ -8,6 +9,8 @@
89

910
from tqdm_loggable.auto import tqdm
1011

12+
import numpy as np
13+
1114
import jax
1215
from jax import jit
1316
import jax.numpy as jnp
@@ -22,6 +25,7 @@
2225
from varipeps.mapping import Map_To_PEPS_Model
2326
from varipeps.ctmrg import CTMRGNotConvergedError, CTMRGGradientNotConvergedError
2427
from varipeps.utils.random import PEPS_Random_Number_Generator
28+
from varipeps.utils.slurm import SlurmUtils
2529

2630
from .inner_function import (
2731
calc_ctmrg_expectation,
@@ -283,6 +287,7 @@ def optimize_peps_network(
283287
[PathLike, Sequence[jnp.ndarray], PEPS_Unit_Cell], None
284288
] = autosave_function,
285289
additional_input: Dict[str, jnp.ndarray] = {},
290+
slurm_restart_script: Optional[PathLike] = None,
286291
) -> Tuple[Sequence[jnp.ndarray], PEPS_Unit_Cell, Union[float, jnp.ndarray]]:
287292
"""
288293
Optimize a PEPS unitcell using a variational method.
@@ -830,6 +835,32 @@ def random_noise(a):
830835
best_unitcell = working_unitcell
831836
best_run = random_noise_retries
832837

838+
if (
839+
slurm_restart_script is not None
840+
and (slurm_data := SlurmUtils.get_own_job_data()) is not None
841+
):
842+
flatten_runtime = [j for i in step_runtime for j in i]
843+
runtime_mean = np.mean(flatten_runtime)
844+
runtime_std = np.std(flatten_runtime)
845+
846+
if runtime_std > 0:
847+
remaining_slurm_time = (
848+
slurm_data["TimeLimit"] - slurm_data["RunTime"]
849+
)
850+
time_of_one_step = datetime.timedelta(
851+
seconds=runtime_mean + 3 * runtime_std
852+
)
853+
854+
if remaining_slurm_time < time_of_one_step:
855+
new_job_id = SlurmUtils.run_slurm_script(slurm_restart_script)
856+
if new_job_id is not None:
857+
tqdm.write(f"Started new Slurm job with ID {new_job_id:d}.")
858+
else:
859+
tqdm.write(
860+
"Failed to start new Slurm job or parse its job id."
861+
)
862+
break
863+
833864
if working_value < best_value:
834865
best_value = working_value
835866
best_tensors = working_tensors

varipeps/utils/slurm.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,48 @@ def get_own_job_data(cls):
9191
if (job_id := os.environ.get("SLURM_JOB_ID")) is not None:
9292
return cls.get_job_data(job_id)
9393
return None
94+
95+
@staticmethod
96+
def run_slurm_script(path):
97+
p = subprocess.run(
98+
["scontrol", "show", "job", f"{job_id:d}"], capture_output=True, text=True
99+
)
100+
101+
if p.returncode != 0:
102+
return None
103+
104+
job_data = p.stdout.split()
105+
106+
slice_comb_list = []
107+
for i, e in enumerate(job_data):
108+
if "=" not in e:
109+
slice_comb_list[-1] = slice(slice_comb_list[-1].start, i + 1)
110+
else:
111+
slice_comb_list.append(slice(i, i + 1))
112+
113+
job_data = ["".join(job_data[s]) for s in slice_comb_list]
114+
job_data = dict(e.split("=", 1) for e in job_data)
115+
116+
job_data = cls.parse_special_fields(job_data)
117+
118+
return job_data
119+
120+
@classmethod
121+
def get_own_job_data(cls):
122+
if (job_id := os.environ.get("SLURM_JOB_ID")) is not None:
123+
return cls.get_job_data(job_id)
124+
return None
125+
126+
@staticmethod
127+
def run_slurm_script(path):
128+
p = subprocess.run(["sbatch", str(path)], capture_output=True, text=True)
129+
130+
if p.returncode != 0:
131+
return None
132+
133+
try:
134+
job_id = int(p.stdout.split()[-1])
135+
except ValueError:
136+
job_id = None
137+
138+
return job_id

0 commit comments

Comments
 (0)