11import datetime
2+ import math
23import os
4+ import pathlib
5+ import sys
36import subprocess
7+ import textwrap
48
59
610class SlurmUtils :
@@ -15,8 +19,9 @@ def parse_special_fields(job_data):
1519 "StartTime" ,
1620 "EndTime" ,
1721 "LastSchedEval" ,
22+ "PreemptEligibleTime" ,
1823 ):
19- if (entry := job_data .get (field )) is not None :
24+ if (entry := job_data .get (field )) is not None and entry != "Unknown" :
2025 job_data [field ] = datetime .datetime .fromisoformat (entry )
2126
2227 for field in ("RunTime" , "TimeLimit" , "DelayBoot" ):
@@ -37,10 +42,11 @@ def parse_special_fields(job_data):
3742 seconds = int (seconds ),
3843 )
3944
40- if (entry := job_data .get ("TRES" )) is not None :
41- entry = entry .split ("," )
42- if len (entry ) > 0 :
43- job_data ["TRES" ] = dict (e .split ("=" , 1 ) for e in entry )
45+ for field in ("TRES" , "ReqTRES" , "AllocTRES" ):
46+ if (entry := job_data .get (field )) is not None :
47+ entry = entry .split ("," )
48+ if len (entry ) > 0 :
49+ job_data [field ] = dict (e .split ("=" , 1 ) for e in entry )
4450
4551 for field in job_data :
4652 try :
@@ -93,40 +99,13 @@ def get_own_job_data(cls):
9399 return None
94100
95101 @staticmethod
96- def run_slurm_script (path ):
102+ def run_slurm_script (path , cwd = None ):
103+ cwd = pathlib .Path (cwd ).resolve ()
104+
97105 p = subprocess .run (
98- ["scontrol " , "show" , "job" , f" { job_id :d } " ], capture_output = True , text = True
106+ ["sbatch " , str ( path ) ], capture_output = True , text = True , cwd = cwd
99107 )
100108
101- if p .returncode != 0 :
102- return None
103-
104- job_data = p .stdout .split ()
105-
106- slice_comb_list = []
107- for i , e in enumerate (job_data ):
108- if "=" not in e :
109- slice_comb_list [- 1 ] = slice (slice_comb_list [- 1 ].start , i + 1 )
110- else :
111- slice_comb_list .append (slice (i , i + 1 ))
112-
113- job_data = ["" .join (job_data [s ]) for s in slice_comb_list ]
114- job_data = dict (e .split ("=" , 1 ) for e in job_data )
115-
116- job_data = cls .parse_special_fields (job_data )
117-
118- return job_data
119-
120- @classmethod
121- def get_own_job_data (cls ):
122- if (job_id := os .environ .get ("SLURM_JOB_ID" )) is not None :
123- return cls .get_job_data (job_id )
124- return None
125-
126- @staticmethod
127- def run_slurm_script (path ):
128- p = subprocess .run (["sbatch" , str (path )], capture_output = True , text = True )
129-
130109 if p .returncode != 0 :
131110 return None
132111
@@ -136,3 +115,96 @@ def run_slurm_script(path):
136115 job_id = None
137116
138117 return job_id
118+
119+ @staticmethod
120+ def generate_restart_scripts (
121+ slurm_script_path ,
122+ python_script_path ,
123+ restart_state_file ,
124+ slurm_data ,
125+ executable = None ,
126+ ):
127+ TEMPLATE_PYTHON = textwrap .dedent (
128+ """\
129+ #!/usr/bin/env python3
130+ import argparse
131+ import pathlib
132+ import varipeps
133+
134+ parser = argparse.ArgumentParser()
135+ parser.add_argument('filename', type=pathlib.Path)
136+ args = parser.parse_args()
137+
138+ varipeps.optimization.restart_from_state_file(args.filename)
139+ """
140+ )
141+
142+ TEMPLATE_SLURM = textwrap .dedent (
143+ """\
144+ #!/bin/bash
145+
146+ #SBATCH --partition={partition}
147+ #SBATCH --qos={qos}
148+ #SBATCH --job-name={job_name}
149+ #SBATCH --ntasks={ntasks:d}
150+ #SBATCH --nodes={nodes:d}
151+ #SBATCH --cpus-per-task={ncpus:d}
152+ #SBATCH --mem={mem}
153+ #SBATCH --time={time_limit}
154+ #SBATCH --mail-type=FAIL,END
155+
156+ "{executable}" "{python_script}" "{state_file}"
157+ """
158+ )
159+
160+ python_script_path = pathlib .Path (python_script_path ).resolve ()
161+
162+ restart_state_file = pathlib .Path (restart_state_file ).resolve ()
163+
164+ if executable is None :
165+ executable = pathlib .Path (sys .executable ).absolute ()
166+
167+ if (tres := slurm_data .get ("ReqTRES" )) is not None :
168+ mem = tres ["mem" ]
169+ else :
170+ tres = slurm_data ["TRES" ]
171+ mem = tres ["mem" ]
172+
173+ time_limit = slurm_data ["TimeLimit" ]
174+ if time_limit .days > 0 :
175+ time_limit_diff = time_limit - datetime .timedelta (days = time_limit .days )
176+ else :
177+ time_limit_diff = time_limit
178+
179+ time_limit_hours = math .floor (time_limit_diff / datetime .timedelta (hours = 1 ))
180+ time_limit_diff -= datetime .timedelta (hours = time_limit_hours )
181+
182+ time_limit_minutes = math .floor (time_limit_diff / datetime .timedelta (minutes = 1 ))
183+ time_limit_diff -= datetime .timedelta (minutes = time_limit_minutes )
184+
185+ time_limit_seconds = math .floor (time_limit_diff .total_seconds ())
186+
187+ if time_limit .days > 0 :
188+ time_limit_str = f"{ time_limit .days } -{ time_limit_hours :02d} :{ time_limit_minutes :02d} :{ time_limit_seconds :02d} "
189+ else :
190+ time_limit_str = f"{ time_limit_hours :02d} :{ time_limit_minutes :02d} :{ time_limit_seconds :02d} "
191+
192+ slurm_file_content = TEMPLATE_SLURM .format (
193+ partition = slurm_data ["Partition" ],
194+ qos = slurm_data ["QOS" ],
195+ job_name = f"{ slurm_data ['JobName' ]} _restarted" ,
196+ ntasks = slurm_data ["NumTasks" ],
197+ nodes = slurm_data ["NumNodes" ],
198+ ncpus = slurm_data ["CPUs/Task" ],
199+ mem = mem ,
200+ time_limit = time_limit_str ,
201+ executable = str (executable ),
202+ python_script = str (python_script_path ),
203+ state_file = restart_state_file ,
204+ )
205+
206+ with python_script_path .open ("w" ) as f :
207+ f .write (TEMPLATE_PYTHON )
208+
209+ with pathlib .Path (slurm_script_path ).open ("w" ) as f :
210+ f .write (slurm_file_content )
0 commit comments