Commit fad8e5c

attempt at decoupling bg

1 parent 49e712b

File tree: 1 file changed, +151 −28 lines
flarestack/core/results.py (151 additions, 28 deletions)

@@ -37,6 +37,113 @@ class OverfluctuationError(Exception):
     pass


+class PickleCache:
+    def __init__(self, pickle_path: Path, background_only: bool = False):
+        self.path = pickle_path
+        # self.pickle_path.mkdir(parents=True, exist_ok=True)
+        self.merged_path = self.path / "merged"
+        self.merged_path.mkdir(parents=True, exist_ok=True)
+        self.background_only = background_only
+
+    def clean_merged_data(self):
+        """Function to clear cache of all data"""
+        # remove all files inside merged_path using pathlib
+        if self.merged_path.exists():
+            for file in self.merged_path.iterdir():
+                if file.is_file():
+                    file.unlink()
+            logger.debug(f"Removed all files from {self.merged_path}")
+
+    def get_subdirs(self):
+        return [
+            x
+            for x in os.listdir(self.path)
+            if x[0] != "." and x != "merged"
+        ]
+
+    def merge_datadict(self, merged: dict[str, list | dict], pending_data: dict[str, list | dict]):
+        """Merge the content of pending_data into merged."""
+        for key, info in pending_data.items():
+            if isinstance(info, list):
+                # Append the list to the existing one.
+                merged[key] += info
+            elif isinstance(info, dict):
+                for param_name, params in info.items():
+                    try:
+                        merged[key][param_name] += params
+                    except KeyError as m:
+                        logger.warning(
+                            f"Keys [{key}][{param_name}] not found in \n {merged}"
+                        )
+                        raise KeyError(m)
+            else:
+                raise TypeError(
+                    f"Unexpected type for key {key}: {type(info)}. Expected list or dict."
+                )
+
+    def merge_and_load_subdir(self, subdir_name):
+        """Merge and load data from a single subdirectory."""
+        subdir = os.path.join(self.path, subdir_name)
+
+        files = os.listdir(subdir)
+
+        # Map one dir to one pickle.
+        merged_file_path = os.path.join(self.merged_path, subdir_name + ".pkl")
+        # Load previously merged data, if it exists.
+        if os.path.isfile(merged_file_path):
+            logger.debug(f"loading merged data from {merged_file_path}")
+            with open(merged_file_path, "rb") as mp:
+                merged_data = Pickle.load(mp)
+        else:
+            merged_data = {}
+
+        for filename in files:
+            pending_file = os.path.join(subdir, filename)
+
+            try:
+                with open(pending_file, "rb") as f:
+                    data = Pickle.load(f)
+            except (EOFError, IsADirectoryError):
+                logger.warning("Failed loading: {0}".format(pending_file))
+                continue
+            # Remove the file immediately. This can lead to undesired results: if the program crashes or is terminated here, files are removed before the merged data is written. However, delaying the removal would create the opposite problem: the file is merged but not removed, and potentially merged a second time on the next run.
+            os.remove(pending_file)
+
+            if merged_data == {}:
+                merged_data = data
+            else:
+                self.merge_datadict(merged_data, data)
+
+        # Save merged data.
+        with open(merged_file_path, "wb") as mp:
+            Pickle.dump(merged_data, mp)
+
+        return merged_data
+
+    def merge_and_load(self, output_dict: dict):
+        # Loop over all injection scale subdirectories.
+        scales_subdirs = self.get_subdirs()
+
+        background_label = scale_shortener(0.0)
+
+        for subdir_name in scales_subdirs:
+            scale_label = scale_shortener(float(subdir_name))
+
+            if self.background_only and scale_label != background_label:
+                # Skip non-background trials.
+                continue
+
+            pending_data = self.merge_and_load_subdir(subdir_name)
+
+            if pending_data:
+                if scale_label == background_label and background_label in output_dict:
+                    self.merge_datadict(output_dict[background_label], pending_data)
+                else:
+                    output_dict[scale_label] = pending_data
+
+
 class ResultsHandler(object):
     def __init__(
         self,
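For orientation, a minimal usage sketch of the new class, assuming the pickle layout it expects (one subdirectory per injection scale, one pickle per trial batch). The paths here are hypothetical:

from pathlib import Path

results: dict = {}

# Cache over this analysis' own trial pickles (hypothetical path).
cache = PickleCache(Path("/path/to/pickles/my_analysis"))

# Cache over background-only trials, possibly written by a different analysis.
bg_cache = PickleCache(Path("/path/to/pickles/shared_background"), background_only=True)

# Merge per-trial pickles into one pickle per scale, then load them.
cache.merge_and_load(output_dict=results)

# With background_only=True, only the scale-0 (background) subdirectory is processed.
bg_cache.merge_and_load(output_dict=results)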
@@ -45,10 +152,17 @@ def __init__(
         do_disc=True,
         bias_error="std",
         sigma_thresholds=[3.0, 5.0],
+        background_from=None,
     ):
         self.sources = load_catalogue(rh_dict["catalogue"])

         self.name = rh_dict["name"]
+
+        if background_from is not None:
+            self.background_from = background_from
+        else:
+            self.background_from = rh_dict["name"]
+
         self.mh_name = rh_dict["mh_name"]

         self._inj_dict = rh_dict["inj_dict"]
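A sketch of the intended call pattern for the new keyword (the configuration dicts and the "shared_bg" analysis name are hypothetical): several handlers can reuse the background trials recorded under one shared name, while omitting background_from preserves the old behaviour.

# Both handlers reuse background trials produced under the name "shared_bg".
rh_a = ResultsHandler(rh_dict_a, background_from="shared_bg")
rh_b = ResultsHandler(rh_dict_b, background_from="shared_bg")

# Default: background_from=None falls back to the analysis' own name.
rh_c = ResultsHandler(rh_dict_c)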
@@ -59,11 +173,14 @@ def __init__(
         self.maxfev = rh_dict.get("maxfev", 800)

         self.results = dict()
+
         self.pickle_output_dir = name_pickle_output_dir(self.name)
+        self.pickle_output_dir_bg = name_pickle_output_dir(self.background_from)

-        self.plot_path = Path(plot_output_dir(self.name))
+        self.pickle_cache = PickleCache(Path(self.pickle_output_dir))
+        self.pickle_cache_bg = PickleCache(Path(self.pickle_output_dir_bg))

-        self.merged_dir = os.path.join(self.pickle_output_dir, "merged")
+        self.plot_path = Path(plot_output_dir(self.name))

         self.allow_extrapolation = rh_dict.get("allow_extrapolated_sensitivity", True)

@@ -101,7 +218,8 @@ def __init__(
             "extrapolated": False,
         }

-        # Load injection ladder values
+        # Load injection ladder values.
+        # Builds a dictionary mapping each injection scale to the content of its trials.
         try:
             self.inj = self.load_injection_values()
         except FileNotFoundError as err:
@@ -112,17 +230,17 @@ def __init__(
             self.valid = False
             return

-        # Load and merge the trial results
         try:
-            self.merge_pickle_data()
+            self.pickle_cache.merge_and_load(output_dict=self.results)
+            # Load the background trials, merging them into any background entry already loaded.
+            self.pickle_cache_bg.merge_and_load(output_dict=self.results)
         except FileNotFoundError:
             logger.warning(f"No files found at {self.pickle_output_dir}")

         # auxiliary parameters
-        # self.sorted_scales = sorted(self.results.keys())
         self.scale_values = sorted(
             [float(j) for j in self.results.keys()]
-        )  # replaces self.scales_float
+        )
         self.scale_labels = [scale_shortener(i) for i in self.scale_values]

         logger.info(f"Injection scales: {self.scale_values}")
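The merge rule used here is the one implemented by PickleCache.merge_datadict: list values are concatenated, and dict values are merged key by key. A small worked example (the key names "TS", "Parameters" and "n_s" are illustrative):

from pathlib import Path

# Instantiating PickleCache creates <path>/merged as a side effect; the path is hypothetical.
cache = PickleCache(Path("/tmp/flarestack_example"))

merged = {"TS": [1.2, 3.4], "Parameters": {"n_s": [5.0]}}
pending = {"TS": [0.7], "Parameters": {"n_s": [4.1]}}

cache.merge_datadict(merged, pending)

# Top-level lists are extended in place, nested parameter lists per key.
assert merged == {"TS": [1.2, 3.4, 0.7], "Parameters": {"n_s": [5.0, 4.1]}}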
@@ -131,17 +249,14 @@ def __init__(
         # Determine the injection scales
         try:
             self.find_ns_scale()
+            self.plot_bias()
         except ValueError as e:
             logger.warning(f"RuntimeError for ns scale factor: \n {e}")
         except IndexError as e:
             logger.warning(
                 f"IndexError for ns scale factor. Only background trials? \n {e}"
             )

-        # Create fit bias plots
-        # this expects flux_to_ns to be set
-        self.plot_bias()
-
         if do_sens:
             try:
                 self.find_sensitivity()
@@ -278,12 +393,9 @@ def nu_astronomy(self, flux, e_pdf_dict):
         return calculate_astronomy(flux, e_pdf_dict)

     def clean_merged_data(self):
-        """Function to clear cache of all data"""
-        try:
-            for f in os.listdir(self.merged_dir):
-                os.remove(self.merged_dir + f)
-        except OSError:
-            pass
+        """Clean the merged data of the main analysis' pickle cache. Do not touch the background cache."""
+        self.pickle_cache.clean_merged_data()
+

     def load_injection_values(self):
@@ -311,25 +423,33 @@ def load_injection_values(self):

         return inj_values

-    def merge_pickle_data(self):
+
+    def merge_and_load_pickle_data(self):
+        # NOTE:
+        # self.pickle_output_path
+        # self.merged_dir = self.pickle_output_path / "merged"
+
+
+        # Loop over all subdirectories, one per injection scale, each containing one pickle per trial.
         all_sub_dirs = [
             x
-            for x in os.listdir(self.pickle_output_dir)
+            for x in os.listdir(self.path)
             if x[0] != "." and x != "merged"
         ]
-
+        # Create a "merged" directory, which will contain a single pickle (holding many trials) per injection scale.
         try:
             os.makedirs(self.merged_dir)
         except OSError:
             pass

         for sub_dir_name in all_sub_dirs:
-            sub_dir = os.path.join(self.pickle_output_dir, sub_dir_name)
+            sub_dir = os.path.join(self.path, sub_dir_name)

             files = os.listdir(sub_dir)

+            # Map one dir to one pickle.
             merged_path = os.path.join(self.merged_dir, sub_dir_name + ".pkl")
-
+            # Load previously merged data, if it exists.
             if os.path.isfile(merged_path):
                 logger.debug(f"loading merged data from {merged_path}")
                 with open(merged_path, "rb") as mp:
@@ -338,15 +458,16 @@ def merge_pickle_data(self):
                 merged_data = {}

             for filename in files:
-                path = os.path.join(sub_dir, filename)
+                pending_file = os.path.join(sub_dir, filename)

                 try:
-                    with open(path, "rb") as f:
+                    with open(pending_file, "rb") as f:
                         data = Pickle.load(f)
                 except (EOFError, IsADirectoryError):
-                    logger.warning("Failed loading: {0}".format(path))
+                    logger.warning("Failed loading: {0}".format(pending_file))
                     continue
-                os.remove(path)
+                # This can be "dangerous" because if the program crashes or gets terminated, we will have removed files before writing the merged data.
+                os.remove(pending_file)

                 if merged_data == {}:
                     merged_data = data
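The "dangerous" window flagged in the comment above sits between os.remove and the later Pickle.dump of the merged data. A conventional mitigation, sketched below as an alternative and not what this commit does, is to dump to a temporary file, swap it in atomically with os.replace, and delete the inputs only afterwards; a crash then means some trials are merged twice on the next run rather than lost. The helper name and merge callback are hypothetical:

import os
import pickle
import tempfile

def merge_subdir_atomic(subdir: str, merged_file_path: str, merge) -> dict:
    """Sketch: merge every trial pickle in subdir into merged_file_path,
    removing inputs only once the merged file is safely on disk."""
    merged_data: dict = {}
    if os.path.isfile(merged_file_path):
        with open(merged_file_path, "rb") as f:
            merged_data = pickle.load(f)

    pending = [os.path.join(subdir, name) for name in os.listdir(subdir)]
    for path in pending:
        with open(path, "rb") as f:
            data = pickle.load(f)
        if not merged_data:
            merged_data = data
        else:
            merge(merged_data, data)  # e.g. PickleCache.merge_datadict

    # Dump next to the target, then rename: os.replace is atomic on POSIX,
    # so a crash leaves either the old or the new file, never a torn one.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(merged_file_path))
    with os.fdopen(fd, "wb") as f:
        pickle.dump(merged_data, f)
    os.replace(tmp_path, merged_file_path)

    # Deleting inputs last means a crash above only causes re-merging
    # (duplicate trials on the next run), never data loss.
    for path in pending:
        os.remove(path)
    return merged_data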
@@ -364,13 +485,15 @@
                         )
                         raise KeyError(m)

+            # Save merged data.
             with open(merged_path, "wb") as mp:
                 Pickle.dump(merged_data, mp)

-            if len(list(merged_data.keys())) > 0:
+            # Load merged data into results.
+            if merged_data:
                 self.results[scale_shortener(float(sub_dir_name))] = merged_data

-        if len(list(self.results.keys())) == 0:
+        if not self.results:
             logger.warning("No data was found by ResultsHandler object! \n")
             logger.warning(
                 "Tried root directory: \n {0} \n ".format(self.pickle_output_dir)