
Commit 53b91fa

Add code which restarts slurm job from state file
1 parent cb50ab5 commit 53b91fa

File tree: 4 files changed, +164 -51 lines


varipeps/config.py

Lines changed: 17 additions & 0 deletions
@@ -39,6 +39,17 @@ class Wavevector_Type(IntEnum):
     TWO_PI_SYMMETRIC = auto()  #: Use interval [-2pi, 2pi) for q vectors
 
 
+@unique
+class Slurm_Restart_Mode(IntEnum):
+    DISABLED = (
+        auto()
+    )  #: Disable automatic restart of slurm job if maximal runtime limit is reached
+    WRITE_RESTART_SCRIPT = (
+        auto()
+    )  #: Write restart script but do not submit new slurm job
+    AUTOMATIC_RESTART = auto()  #: Write restart script and start new slurm job with it
+
+
 @dataclass
 @register_pytree_node_class
 class VariPEPS_Config:
@@ -190,6 +201,8 @@ class VariPEPS_Config:
         :obj:`scipy.optimize.basinhopping`. See this function for details.
       spiral_wavevector_type (:obj:`Wavevector_Type`):
         Type of wavevector to be used (only positive/symmetric interval/...).
+      slurm_restart_mode (:obj:`Slurm_Restart_Mode`):
+        Mode of operation to restart slurm job if maximal runtime is reached.
     """
 
     # AD config
@@ -265,6 +278,9 @@ class VariPEPS_Config:
     # Spiral PEPS
     spiral_wavevector_type: Wavevector_Type = Wavevector_Type.TWO_PI_POSITIVE_ONLY
 
+    # Slurm
+    slurm_restart_mode: Slurm_Restart_Mode = Slurm_Restart_Mode.DISABLED
+
     def update(self, name: str, value: Any) -> NoReturn:
         self.__setattr__(name, value)
 
@@ -349,6 +365,7 @@ class ConfigModuleWrapper:
         "Line_Search_Methods",
         "Projector_Method",
         "Wavevector_Type",
+        "Slurm_Restart_Mode",
         "VariPEPS_Config",
         "config",
     }
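The new option is an ordinary field of the global configuration object, so it can be switched on like any other setting. A minimal sketch, based only on the imports and the `update` helper visible in this diff (the default stays `DISABLED`):

import varipeps
from varipeps import varipeps_config
from varipeps.config import Slurm_Restart_Mode

# Write a restart script and submit it automatically once the remaining
# Slurm time budget is too short for another optimizer step.
varipeps_config.update("slurm_restart_mode", Slurm_Restart_Mode.AUTOMATIC_RESTART)

# Or only write the restart script without submitting a new job:
# varipeps_config.update("slurm_restart_mode", Slurm_Restart_Mode.WRITE_RESTART_SCRIPT)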

varipeps/optimization/optimizer.py

Lines changed: 35 additions & 15 deletions
@@ -21,7 +21,7 @@
 from jax.flatten_util import ravel_pytree
 
 from varipeps import varipeps_config, varipeps_global_state
-from varipeps.config import Optimizing_Methods
+from varipeps.config import Optimizing_Methods, Slurm_Restart_Mode
 from varipeps.peps import PEPS_Unit_Cell
 from varipeps.expectation import Expectation_Model
 from varipeps.config import Projector_Method
@@ -403,7 +403,6 @@ def optimize_peps_network(
     ] = autosave_function,
     additional_input: Dict[str, jnp.ndarray] = {},
     restart_state: Dict[str, Any] = {},
-    slurm_restart_script: Optional[PathLike] = None,
 ) -> Tuple[Sequence[jnp.ndarray], PEPS_Unit_Cell, Union[float, jnp.ndarray]]:
     """
     Optimize a PEPS unitcell using a variational method.
@@ -542,6 +541,9 @@ def random_noise(a):
     else:
         varipeps_global_state.ctmrg_projector_method = None
 
+    slurm_restart_written = False
+    slurm_new_job_id = None
+
     with tqdm(desc="Optimizing PEPS state", initial=count) as pbar:
         while count < varipeps_config.optimizer_max_steps:
             runtime_start = time.perf_counter()
@@ -1067,26 +1069,37 @@ def random_noise(a):
                 best_run = random_noise_retries
 
             if (
-                slurm_restart_script is not None
+                varipeps_config.slurm_restart_mode is not Slurm_Restart_Mode.DISABLED
                 and (slurm_data := SlurmUtils.get_own_job_data()) is not None
             ):
-                flatten_runtime = [j for i in step_runtime for j in i]
+                flatten_runtime = [j for i in step_runtime for j in step_runtime[i]]
                 runtime_mean = np.mean(flatten_runtime)
                 runtime_std = np.std(flatten_runtime)
 
-                if runtime_std > 0:
-                    remaining_slurm_time = (
-                        slurm_data["TimeLimit"] - slurm_data["RunTime"]
-                    )
-                    time_of_one_step = datetime.timedelta(
-                        seconds=runtime_mean + 3 * runtime_std
+                remaining_slurm_time = slurm_data["TimeLimit"] - slurm_data["RunTime"]
+                time_of_one_step = datetime.timedelta(
+                    seconds=runtime_mean + 3 * runtime_std
+                )
+
+                if remaining_slurm_time < time_of_one_step:
+                    SlurmUtils.generate_restart_scripts(
+                        f"{str(autosave_filename)}.restart.slurm",
+                        f"{str(autosave_filename)}.restart.py",
+                        f"{str(autosave_filename)}.restartable",
+                        slurm_data,
                     )
 
-                    if remaining_slurm_time < time_of_one_step:
-                        new_job_id = SlurmUtils.run_slurm_script(slurm_restart_script)
-                        if new_job_id is not None:
-                            tqdm.write(f"Started new Slurm job with ID {new_job_id:d}.")
-                        else:
+                    slurm_restart_written = True
+
+                    if (
+                        varipeps_config.slurm_restart_mode
+                        is Slurm_Restart_Mode.AUTOMATIC_RESTART
+                    ):
+                        slurm_new_job_id = SlurmUtils.run_slurm_script(
+                            f"{str(autosave_filename)}.restart.slurm",
+                            slurm_data["WorkDir"],
+                        )
+                        if slurm_new_job_id is None:
                             tqdm.write(
                                 "Failed to start new Slurm job or parse its job id."
                             )
@@ -1124,6 +1137,11 @@ def random_noise(a):
 
     print(f"Best energy result found: {best_value}")
 
+    if slurm_restart_written:
+        print("Wrote script to restart optimizer job with Slurm.")
+    if slurm_new_job_id is not None:
+        print(f"Started new Slurm job with ID {slurm_new_job_id:d}.")
+
     return OptimizeResult(
         success=True,
         x=best_tensors,
@@ -1136,6 +1154,8 @@ def random_noise(a):
         step_conv=step_conv,
         step_runtime=step_runtime,
         best_run=best_run,
+        slurm_restart_written=slurm_restart_written,
+        slurm_new_job_id=slurm_new_job_id,
     )
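The restart trigger in the loop above boils down to a simple runtime estimate: one more optimizer step is assumed to cost about mean + 3·std of the step runtimes seen so far, and the restart script is written once the job's remaining time drops below that. A standalone sketch of the check with made-up numbers (in the real code `TimeLimit` and `RunTime` come from `SlurmUtils.get_own_job_data()`, already parsed into `datetime.timedelta` objects):

import datetime

import numpy as np

# Hypothetical per-step runtimes in seconds collected during the optimization
flatten_runtime = [112.0, 118.5, 121.3, 109.8]

# Conservative estimate for the cost of one more step
time_of_one_step = datetime.timedelta(
    seconds=np.mean(flatten_runtime) + 3 * np.std(flatten_runtime)
)

# Hypothetical Slurm job data: 24 h limit, 23 h 58 min already used
remaining_slurm_time = datetime.timedelta(hours=24) - datetime.timedelta(hours=23, minutes=58)

if remaining_slurm_time < time_of_one_step:
    print("Not enough time left for another step; write the restart script now.")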

varipeps/peps/unitcell.py

Lines changed: 4 additions & 0 deletions
@@ -1230,6 +1230,10 @@ def load_from_group(
             config_dict["spiral_wavevector_type"] = varipeps.config.Wavevector_Type(
                 config_dict["spiral_wavevector_type"]
             )
+            if config_dict.get("slurm_restart_mode"):
+                config_dict["slurm_restart_mode"] = varipeps.config.Slurm_Restart_Mode(
+                    config_dict["slurm_restart_mode"]
+                )
 
         return cls(
             data=data, real_ix=real_ix, real_iy=real_iy
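On disk the configuration stores the enum as its integer value, so loading converts it back, exactly as already done for `spiral_wavevector_type`. A tiny round-trip illustration (file handling omitted; the numbering follows from `auto()` assigning 1, 2, 3 to the members):

from varipeps.config import Slurm_Restart_Mode

stored_value = int(Slurm_Restart_Mode.WRITE_RESTART_SCRIPT)  # 2, as written into the state file
restored = Slurm_Restart_Mode(stored_value)                  # converted back on load
assert restored is Slurm_Restart_Mode.WRITE_RESTART_SCRIPT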

varipeps/utils/slurm.py

Lines changed: 108 additions & 36 deletions
@@ -1,6 +1,10 @@
 import datetime
+import math
 import os
+import pathlib
+import sys
 import subprocess
+import textwrap
 
 
 class SlurmUtils:
@@ -15,8 +19,9 @@ def parse_special_fields(job_data):
             "StartTime",
             "EndTime",
             "LastSchedEval",
+            "PreemptEligibleTime",
         ):
-            if (entry := job_data.get(field)) is not None:
+            if (entry := job_data.get(field)) is not None and entry != "Unknown":
                 job_data[field] = datetime.datetime.fromisoformat(entry)
 
         for field in ("RunTime", "TimeLimit", "DelayBoot"):
@@ -37,10 +42,11 @@ def parse_special_fields(job_data):
                 seconds=int(seconds),
             )
 
-        if (entry := job_data.get("TRES")) is not None:
-            entry = entry.split(",")
-            if len(entry) > 0:
-                job_data["TRES"] = dict(e.split("=", 1) for e in entry)
+        for field in ("TRES", "ReqTRES", "AllocTRES"):
+            if (entry := job_data.get(field)) is not None:
+                entry = entry.split(",")
+                if len(entry) > 0:
+                    job_data[field] = dict(e.split("=", 1) for e in entry)
 
         for field in job_data:
             try:
@@ -93,40 +99,13 @@ def get_own_job_data(cls):
         return None
 
     @staticmethod
-    def run_slurm_script(path):
+    def run_slurm_script(path, cwd=None):
+        cwd = pathlib.Path(cwd).resolve()
+
         p = subprocess.run(
-            ["scontrol", "show", "job", f"{job_id:d}"], capture_output=True, text=True
+            ["sbatch", str(path)], capture_output=True, text=True, cwd=cwd
         )
 
-        if p.returncode != 0:
-            return None
-
-        job_data = p.stdout.split()
-
-        slice_comb_list = []
-        for i, e in enumerate(job_data):
-            if "=" not in e:
-                slice_comb_list[-1] = slice(slice_comb_list[-1].start, i + 1)
-            else:
-                slice_comb_list.append(slice(i, i + 1))
-
-        job_data = ["".join(job_data[s]) for s in slice_comb_list]
-        job_data = dict(e.split("=", 1) for e in job_data)
-
-        job_data = cls.parse_special_fields(job_data)
-
-        return job_data
-
-    @classmethod
-    def get_own_job_data(cls):
-        if (job_id := os.environ.get("SLURM_JOB_ID")) is not None:
-            return cls.get_job_data(job_id)
-        return None
-
-    @staticmethod
-    def run_slurm_script(path):
-        p = subprocess.run(["sbatch", str(path)], capture_output=True, text=True)
-
         if p.returncode != 0:
             return None
 
@@ -136,3 +115,96 @@ def run_slurm_script(path):
         job_id = None
 
         return job_id
+
+    @staticmethod
+    def generate_restart_scripts(
+        slurm_script_path,
+        python_script_path,
+        restart_state_file,
+        slurm_data,
+        executable=None,
+    ):
+        TEMPLATE_PYTHON = textwrap.dedent(
+            """\
+            #!/usr/bin/env python3
+            import argparse
+            import pathlib
+            import varipeps
+
+            parser = argparse.ArgumentParser()
+            parser.add_argument('filename', type=pathlib.Path)
+            args = parser.parse_args()
+
+            varipeps.optimization.restart_from_state_file(args.filename)
+            """
+        )
+
+        TEMPLATE_SLURM = textwrap.dedent(
+            """\
+            #!/bin/bash
+
+            #SBATCH --partition={partition}
+            #SBATCH --qos={qos}
+            #SBATCH --job-name={job_name}
+            #SBATCH --ntasks={ntasks:d}
+            #SBATCH --nodes={nodes:d}
+            #SBATCH --cpus-per-task={ncpus:d}
+            #SBATCH --mem={mem}
+            #SBATCH --time={time_limit}
+            #SBATCH --mail-type=FAIL,END
+
+            "{executable}" "{python_script}" "{state_file}"
+            """
+        )
+
+        python_script_path = pathlib.Path(python_script_path).resolve()
+
+        restart_state_file = pathlib.Path(restart_state_file).resolve()
+
+        if executable is None:
+            executable = pathlib.Path(sys.executable).absolute()
+
+        if (tres := slurm_data.get("ReqTRES")) is not None:
+            mem = tres["mem"]
+        else:
+            tres = slurm_data["TRES"]
+            mem = tres["mem"]
+
+        time_limit = slurm_data["TimeLimit"]
+        if time_limit.days > 0:
+            time_limit_diff = time_limit - datetime.timedelta(days=time_limit.days)
+        else:
+            time_limit_diff = time_limit
+
+        time_limit_hours = math.floor(time_limit_diff / datetime.timedelta(hours=1))
+        time_limit_diff -= datetime.timedelta(hours=time_limit_hours)
+
+        time_limit_minutes = math.floor(time_limit_diff / datetime.timedelta(minutes=1))
+        time_limit_diff -= datetime.timedelta(minutes=time_limit_minutes)
+
+        time_limit_seconds = math.floor(time_limit_diff.total_seconds())
+
+        if time_limit.days > 0:
+            time_limit_str = f"{time_limit.days}-{time_limit_hours:02d}:{time_limit_minutes:02d}:{time_limit_seconds:02d}"
+        else:
+            time_limit_str = f"{time_limit_hours:02d}:{time_limit_minutes:02d}:{time_limit_seconds:02d}"
+
+        slurm_file_content = TEMPLATE_SLURM.format(
+            partition=slurm_data["Partition"],
+            qos=slurm_data["QOS"],
+            job_name=f"{slurm_data['JobName']}_restarted",
+            ntasks=slurm_data["NumTasks"],
+            nodes=slurm_data["NumNodes"],
+            ncpus=slurm_data["CPUs/Task"],
+            mem=mem,
+            time_limit=time_limit_str,
+            executable=str(executable),
+            python_script=str(python_script_path),
+            state_file=restart_state_file,
+        )
+
+        with python_script_path.open("w") as f:
+            f.write(TEMPLATE_PYTHON)
+
+        with pathlib.Path(slurm_script_path).open("w") as f:
+            f.write(slurm_file_content)
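Taken together, the two helpers are meant to be driven by the optimizer: it derives the three file names from `autosave_filename` and only submits the generated script in `AUTOMATIC_RESTART` mode. A hedged usage sketch with made-up file names, assuming the process runs inside a Slurm job so `get_own_job_data()` returns a parsed dictionary:

from varipeps.utils.slurm import SlurmUtils

slurm_data = SlurmUtils.get_own_job_data()  # None when not running under Slurm

if slurm_data is not None:
    # Write an sbatch script plus a small Python driver that resumes the
    # optimization from the given state file.
    SlurmUtils.generate_restart_scripts(
        "run.restart.slurm",
        "run.restart.py",
        "run.restartable",
        slurm_data,
    )

    # Submit the generated script from the job's original working directory;
    # returns the new job id or None if sbatch failed.
    new_job_id = SlurmUtils.run_slurm_script("run.restart.slurm", slurm_data["WorkDir"])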
