Commit 7a7bf12 (parent: a908887)

Implement method to query slurm data and store it in the unitcell files along with the config

File tree (3 files changed, +111 −0 lines changed):

varipeps/peps/unitcell.py
varipeps/utils/__init__.py
varipeps/utils/slurm.py

varipeps/peps/unitcell.py (17 additions, 0 deletions)

@@ -6,6 +6,7 @@
 
 import collections
 from dataclasses import dataclass
+import datetime
 import pathlib
 from os import PathLike
 import subprocess
@@ -1086,6 +1087,22 @@ def save_to_group(self, grp: h5py.Group, store_config: bool = True) -> None:
         if varipeps.git_tag is not None:
             grp_version.attrs["git_tag"] = varipeps.git_tag
 
+        if (
+            slurm_data := varipeps.utils.slurm.SlurmUtils.get_own_job_data()
+        ) is not None:
+            grp_slurm = grp.create_group("slurm")
+
+            for k, v in slurm_data.items():
+                if isinstance(v, datetime.datetime):
+                    grp_slurm.attrs[k] = v.isoformat()
+                elif isinstance(v, datetime.timedelta):
+                    grp_slurm.attrs[k] = v.total_seconds()
+                elif isinstance(v, dict):
+                    for k2, v2 in v.items():
+                        grp_slurm.attrs[f"{k}_{k2}"] = v2
+                else:
+                    grp_slurm.attrs[k] = v
+
     @classmethod
     def load_from_file(
         cls: Type[T_PEPS_Unit_Cell],
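
The new block flattens the Slurm record into HDF5 attributes on a "slurm" subgroup: datetimes are stored as ISO strings, timedeltas as total seconds, and nested dicts (such as the parsed TRES field) are flattened with a key prefix. A minimal sketch of reading that metadata back, assuming a placeholder file name and group path ("state.hdf5", "unitcell"); only the layout of the "slurm" subgroup itself is taken from the diff above:

import datetime

import h5py


def read_slurm_metadata(grp):
    """Collect the attributes written to the 'slurm' subgroup by save_to_group."""
    if "slurm" not in grp:
        return {}

    slurm_data = dict(grp["slurm"].attrs)

    # Timestamps were stored via datetime.isoformat(); convert a few back.
    for field in ("SubmitTime", "StartTime", "EndTime"):
        if field in slurm_data:
            slurm_data[field] = datetime.datetime.fromisoformat(slurm_data[field])

    return slurm_data


with h5py.File("state.hdf5", "r") as f:
    # "unitcell" stands in for whatever group save_to_group was called on.
    print(read_slurm_metadata(f["unitcell"]))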

varipeps/utils/__init__.py (1 addition, 0 deletions)

@@ -2,4 +2,5 @@
 from . import func_cache
 from . import random
 from . import projector_dict
+from . import slurm
 from . import svd

varipeps/utils/slurm.py (new file, 93 additions, 0 deletions)

@@ -0,0 +1,93 @@
+import datetime
+import os
+import subprocess
+
+
+class SlurmUtils:
+    @staticmethod
+    def parse_special_fields(job_data):
+        job_data = job_data.copy()
+
+        for field in (
+            "SubmitTime",
+            "EligibleTime",
+            "AccrueTime",
+            "StartTime",
+            "EndTime",
+            "LastSchedEval",
+        ):
+            if (entry := job_data.get(field)) is not None:
+                job_data[field] = datetime.datetime.fromisoformat(entry)
+
+        for field in ("RunTime", "TimeLimit", "DelayBoot"):
+            if (entry := job_data.get(field)) is not None:
+                entry = entry.split("-")
+                if len(entry) == 2:
+                    days, time = entry
+                else:
+                    (time,) = entry
+                    days = 0
+
+                hours, minutes, seconds = time.split(":")
+
+                job_data[field] = datetime.timedelta(
+                    days=int(days),
+                    hours=int(hours),
+                    minutes=int(minutes),
+                    seconds=int(seconds),
+                )
+
+        if (entry := job_data.get("TRES")) is not None:
+            entry = entry.split(",")
+            if len(entry) > 0:
+                job_data["TRES"] = dict(e.split("=", 1) for e in entry)
+
+        for field in job_data:
+            try:
+                entry = int(job_data[field])
+                job_data[field] = entry
+            except (ValueError, TypeError):
+                pass
+
+            try:
+                entry = job_data[field]
+                if not isinstance(entry, int):
+                    entry = float(entry)
+                    job_data[field] = entry
+            except (ValueError, TypeError):
+                pass
+
+        return job_data
+
+    @classmethod
+    def get_job_data(cls, job_id):
+        job_id = int(job_id)
+
+        p = subprocess.run(
+            ["scontrol", "show", "job", f"{job_id:d}"], capture_output=True, text=True
+        )
+
+        if p.returncode != 0:
+            return None
+
+        job_data = p.stdout.split()
+
+        slice_comb_list = []
+        for i, e in enumerate(job_data):
+            if "=" not in e:
+                slice_comb_list[-1] = slice(slice_comb_list[-1].start, i + 1)
+            else:
+                slice_comb_list.append(slice(i, i + 1))
+
+        job_data = ["".join(job_data[s]) for s in slice_comb_list]
+        job_data = dict(e.split("=", 1) for e in job_data)
+
+        job_data = cls.parse_special_fields(job_data)
+
+        return job_data
+
+    @classmethod
+    def get_own_job_data(cls):
+        if (job_id := os.environ.get("SLURM_JOB_ID")) is not None:
+            return cls.get_job_data(job_id)
+        return None
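
get_job_data shells out to "scontrol show job <id>", splits the output on whitespace, merges tokens that do not contain "=" back onto the preceding key=value pair, and builds a dict before handing it to parse_special_fields, which converts timestamps, durations, and the TRES list into Python objects. A minimal usage sketch, assuming the module is importable as varipeps.utils.slurm and that scontrol is only available inside an allocation; the hand-written sample record is purely illustrative:

from varipeps.utils.slurm import SlurmUtils

# Inside a Slurm job, SLURM_JOB_ID is set and scontrol is on PATH;
# outside of Slurm this returns None.
job_data = SlurmUtils.get_own_job_data()
if job_data is None:
    print("Not running inside a Slurm job.")
else:
    print(job_data["JobId"], job_data.get("RunTime"), job_data.get("TRES"))

# parse_special_fields can be exercised without scontrol on a record in the
# key=value form that "scontrol show job" produces:
sample = {
    "JobId": "123456",
    "SubmitTime": "2024-05-01T12:00:00",  # becomes datetime.datetime
    "RunTime": "1-02:03:04",              # becomes datetime.timedelta
    "TRES": "cpu=8,mem=16G,node=1",       # becomes a dict of strings
}
print(SlurmUtils.parse_special_fields(sample))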
