248 changes: 162 additions & 86 deletions flarestack/core/results.py
@@ -37,6 +37,123 @@ class OverfluctuationError(Exception):
pass


class PickleCache:
def __init__(self, pickle_path: Path, background_only: bool = False):
self.path = pickle_path
# self.pickle_path.mkdir(parents=True, exist_ok=True)
self.merged_path = self.path / "merged"
self.merged_path.mkdir(parents=True, exist_ok=True)
self.background_only = background_only

def clean_merged_data(self):
"""Function to clear cache of all data"""
# remove all files inside merged_path using Pathlib
if self.merged_path.exists():
for file in self.merged_path.iterdir():
if file.is_file():
file.unlink()
logger.debug(f"Removed all files from {self.merged_path}")

def get_subdirs(self):
return [x for x in os.listdir(self.path) if x[0] != "." and x != "merged"]

def merge_datadict(self, merged, pending_data):
"""Merge the content of pending_data into merged."""
for key, info in pending_data.items():
if isinstance(info, list):
# Append the list to the existing one. We assume that on the left-hand side
# we always have a list, so we ignore the type checking.
merged[key] += info
elif isinstance(info, dict):
for param_name, params in info.items():
try:
merged[key][param_name] += params
except KeyError as m:
logger.warning(
f"Keys [{key}][{param_name}] not found in \n {merged}"
)
raise KeyError(m)
else:
raise TypeError(
f"Unexpected type for key {key}: {type(info)}. Expected list or dict."
)
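
For concreteness, a small illustration of the merge semantics implemented above (not part of this diff; the data are hypothetical and `cache` stands for a PickleCache instance): list-valued keys are concatenated, while dict-valued keys have their per-parameter lists concatenated in place.

# Hypothetical trial dictionaries, illustrative only.
merged = {"TS": [1.2, 0.0], "Parameters": {"n_s": [3.1, 0.0]}}
pending = {"TS": [4.5], "Parameters": {"n_s": [2.2]}}

cache.merge_datadict(merged, pending)

# merged is modified in place:
# {"TS": [1.2, 0.0, 4.5], "Parameters": {"n_s": [3.1, 0.0, 2.2]}}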

def merge_and_load_subdir(self, subdir_name):
"""Merge and load data from a single subdirectory."""
subdir = os.path.join(self.path, subdir_name)

files = os.listdir(subdir)

# Map one dir to one pickle
merged_file_path = os.path.join(self.merged_path, subdir_name + ".pkl")
# Load previously merged data, if it exists.
if os.path.isfile(merged_file_path):
logger.debug(f"loading merged data from {merged_file_path}")
with open(merged_file_path, "rb") as mp:
merged_data = Pickle.load(mp)
else:
merged_data = {}

for filename in files:
pending_file = os.path.join(subdir, filename)

try:
with open(pending_file, "rb") as f:
data = Pickle.load(f)
except (EOFError, IsADirectoryError):
logger.warning("Failed loading: {0}".format(pending_file))
continue
# Remove the file immediately. This can lead to undesired results: if the program
# crashes or is terminated mid-merge, files are removed before the merged data is
# written. However, delaying the removal would create the opposite problem: the file
# is merged but not removed, and potentially merged a second time on the next run.
os.remove(pending_file)

if merged_data == {}:
merged_data = data
else:
self.merge_datadict(merged_data, data)

# Save merged data.
with open(merged_file_path, "wb") as mp:
Pickle.dump(merged_data, mp)

return merged_data
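
As the comment above acknowledges, pending files are deleted before the merged pickle is written back, so a crash in between can lose trials. A possible alternative, sketched below (not part of this PR; the function and argument names are illustrative), is to dump the merged data to a temporary file, atomically replace the merged pickle, and only then delete the source files.

import os
import pickle as Pickle
import tempfile


def merge_subdir_atomically(subdir, merged_file_path, merge_datadict):
    """Sketch: merge all pending pickles in subdir and delete them only after
    the merged result is safely on disk."""
    merged = {}
    if os.path.isfile(merged_file_path):
        with open(merged_file_path, "rb") as mp:
            merged = Pickle.load(mp)

    pending = [
        os.path.join(subdir, f)
        for f in os.listdir(subdir)
        if os.path.isfile(os.path.join(subdir, f))
    ]
    for path in pending:
        with open(path, "rb") as f:
            data = Pickle.load(f)
        if not merged:
            merged = data
        else:
            merge_datadict(merged, data)  # in-place merge, as in PickleCache.merge_datadict

    # Write to a temporary file in the same directory, then atomically swap it in.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(merged_file_path))
    with os.fdopen(fd, "wb") as tmp:
        Pickle.dump(merged, tmp)
    os.replace(tmp_path, merged_file_path)

    # Sources are removed only once the merged data is persisted.
    for path in pending:
        os.remove(path)
    return merged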

def merge_and_load(self, output_dict: dict):
# Loop over all injection scale subdirectories.
scales_subdirs = self.get_subdirs()

background_label = scale_shortener(0.0)

for subdir_name in scales_subdirs:
try:
scale_label = scale_shortener(float(subdir_name))
except ValueError as e:
# If analysis paths are nested, i.e. we have analyses ana1 and ana1/sub1, the
# ana1/sub1 directory will be scanned but should be skipped. Ideally the user
# should avoid nesting analysis directories, but there is no safeguard against
# this behaviour.
logger.debug(
f"Skipping subdirectory {subdir_name} as it does not represent a valid scale. Parent directory: {self.path}"
)
continue

if self.background_only and scale_label != background_label:
# skip non-background trials for background_only mode
continue

pending_data = self.merge_and_load_subdir(subdir_name)

if pending_data:
n_pending = len(pending_data["TS"])

if scale_label == background_label and background_label in output_dict:
logger.info(
f"Appending f{n_pending} background data to {len(output_dict[background_label]['TS'])} existing trials ({scale_label=})"
)
self.merge_datadict(output_dict[background_label], pending_data)
else:
output_dict[scale_label] = pending_data
if self.background_only:
logger.info(
f"Loading {n_pending} background trials ({scale_label=})"
)
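
For orientation, a minimal usage sketch of the class added above (paths are placeholders, and the structure of the loaded dictionaries is inferred from this diff rather than guaranteed):

from pathlib import Path

# Cache for the main analysis: each injection-scale subdirectory is merged
# into a single pickle under <path>/merged/.
cache = PickleCache(Path("/path/to/pickles/my_analysis"))
results: dict = {}
cache.merge_and_load(output_dict=results)

# Optional second cache providing only background (scale 0) trials, e.g. shared
# between analyses; background_only=True skips all non-zero injection scales.
bg_cache = PickleCache(Path("/path/to/pickles/shared_background"), background_only=True)
bg_cache.merge_and_load(output_dict=results)

# After loading, results maps a shortened scale label to merged trial data,
# e.g. results["0"]["TS"] is a list of background test-statistic values and
# dict-valued entries hold per-parameter lists (assumed layout).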


class ResultsHandler(object):
def __init__(
self,
@@ -45,10 +162,14 @@ def __init__(
do_disc=True,
bias_error="std",
sigma_thresholds=[3.0, 5.0],
background_source=None,
):
self.sources = load_catalogue(rh_dict["catalogue"])

self.name = rh_dict["name"]

self.background_source = background_source

self.mh_name = rh_dict["mh_name"]

self._inj_dict = rh_dict["inj_dict"]
@@ -59,11 +180,23 @@ def __init__(
self.maxfev = rh_dict.get("maxfev", 800)

self.results = dict()

self.pickle_output_dir = name_pickle_output_dir(self.name)
self.pickle_output_dir_bg = (
name_pickle_output_dir(self.background_source)
if self.background_source
else None
)

self.plot_path = Path(plot_output_dir(self.name))
self.pickle_cache = PickleCache(Path(self.pickle_output_dir))

self.pickle_cache_bg = (
PickleCache(Path(self.pickle_output_dir_bg), background_only=True)
if self.background_source
else None
)

self.merged_dir = os.path.join(self.pickle_output_dir, "merged")
self.plot_path = Path(plot_output_dir(self.name))

self.allow_extrapolation = rh_dict.get("allow_extrapolated_sensitivity", True)

@@ -101,7 +234,8 @@ def __init__(
"extrapolated": False,
}

# Load injection ladder values
# Load injection ladder values.
# Builds a dictionary mapping the injection scale to the content of the trials.
try:
self.inj = self.load_injection_values()
except FileNotFoundError as err:
@@ -112,17 +246,32 @@ def __init__(
self.valid = False
return

# Load and merge the trial results
try:
self.merge_pickle_data()
self.pickle_cache.merge_and_load(output_dict=self.results)
# Load the background trials. Will override the existing one.
if self.pickle_cache_bg is not None:
print("NOTE!!!! Loading BG")
Copilot AI commented on Jul 2, 2025:
Replace this print statement with a logger call (e.g., logger.info) to maintain consistent logging practices.
Suggested change:
- print("NOTE!!!! Loading BG")
+ logger.info("NOTE!!!! Loading BG")
self.pickle_cache_bg.merge_and_load(output_dict=self.results)
else:
print("NOTE!!!! No BG pickle cache")
Copilot AI commented on Jul 2, 2025:
Replace this print statement with a logger call (e.g., logger.warning) for consistency and better control over output.
Suggested change:
- print("NOTE!!!! No BG pickle cache")
+ logger.warning("NOTE!!!! No BG pickle cache")
if not self.results:
logger.warning("No data was found by ResultsHandler object! \n")
logger.warning(
"Tried root directory: \n {0} \n ".format(self.pickle_output_dir)
)
sys.exit()
if not scale_shortener(0.0) in self.results:
logger.error(
f"No key equal to '0' in results! Keys are {self.results.keys()}"
)

sys.exit()
Comment on lines +262 to +268
Copilot AI commented on Jul 2, 2025:
Avoid calling sys.exit during initialization; consider raising an exception so callers can handle errors more gracefully.
Suggested change:
- sys.exit()
- if not scale_shortener(0.0) in self.results:
- logger.error(
- f"No key equal to '0' in results! Keys are {self.results.keys()}"
- )
- sys.exit()
+ raise RuntimeError("No data was found by ResultsHandler object!")
+ if not scale_shortener(0.0) in self.results:
+ logger.error(
+ f"No key equal to '0' in results! Keys are {self.results.keys()}"
+ )
+ raise KeyError(f"No key equal to '0' in results! Keys are {self.results.keys()}")
Comment on lines +262 to +268
Copilot AI commented on Jul 2, 2025:
Avoid calling sys.exit during initialization; consider raising an exception so callers can handle errors more gracefully.
Suggested change:
- sys.exit()
- if not scale_shortener(0.0) in self.results:
- logger.error(
- f"No key equal to '0' in results! Keys are {self.results.keys()}"
- )
- sys.exit()
+ raise RuntimeError("No data was found by ResultsHandler object!")
+ if not scale_shortener(0.0) in self.results:
+ logger.error(
+ f"No key equal to '0' in results! Keys are {self.results.keys()}"
+ )
+ raise RuntimeError("No key equal to '0' in results!")
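
If the suggestion above were adopted and __init__ raised instead of calling sys.exit, a caller could recover along these lines (a sketch under that assumption, not current behaviour; the exception types match the suggested change):

try:
    rh = ResultsHandler(rh_dict)
except (RuntimeError, KeyError) as err:
    # Missing or incomplete trial data: log and skip this analysis instead of
    # terminating the whole process.
    logger.error(f"Could not build ResultsHandler for {rh_dict['name']}: {err}")
    rh = None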

except FileNotFoundError:
logger.warning(f"No files found at {self.pickle_output_dir}")

# auxiliary parameters
# self.sorted_scales = sorted(self.results.keys())
self.scale_values = sorted(
[float(j) for j in self.results.keys()]
) # replaces self.scales_float
self.scale_values = sorted([float(j) for j in self.results.keys()])
self.scale_labels = [scale_shortener(i) for i in self.scale_values]

logger.info(f"Injection scales: {self.scale_values}")
@@ -131,6 +280,7 @@ def __init__(
# Determine the injection scales
try:
self.find_ns_scale()
self.plot_bias()
except ValueError as e:
logger.warning(f"RuntimeError for ns scale factor: \n {e}")
except IndexError as e:
@@ -142,10 +292,6 @@ def __init__(
if len(self.scale_values) == 1 and self.scale_values[0] == 0:
self.make_plots(self.scale_labels[0])

# Create fit bias plots
# this expects flux_to_ns to be set
self.plot_bias()

if do_sens:
try:
self.find_sensitivity()
@@ -282,12 +428,8 @@ def nu_astronomy(self, flux, e_pdf_dict):
return calculate_astronomy(flux, e_pdf_dict)

def clean_merged_data(self):
"""Function to clear cache of all data"""
try:
for f in os.listdir(self.merged_dir):
os.remove(self.merged_dir + f)
except OSError:
pass
"""Clean merged data from pickle cache, only for main analysis. Do not touch the background cache."""
self.pickle_cache.clean_merged_data()

def load_injection_values(self):
"""Function to load the values used in injection, so that a
@@ -315,72 +457,6 @@ def load_injection_values(self):

return inj_values

def merge_pickle_data(self):
all_sub_dirs = [
x
for x in os.listdir(self.pickle_output_dir)
if x[0] != "." and x != "merged"
]

try:
os.makedirs(self.merged_dir)
except OSError:
pass

for sub_dir_name in all_sub_dirs:
sub_dir = os.path.join(self.pickle_output_dir, sub_dir_name)

files = os.listdir(sub_dir)

merged_path = os.path.join(self.merged_dir, sub_dir_name + ".pkl")

if os.path.isfile(merged_path):
logger.debug(f"loading merged data from {merged_path}")
with open(merged_path, "rb") as mp:
merged_data = Pickle.load(mp)
else:
merged_data = {}

for filename in files:
path = os.path.join(sub_dir, filename)

try:
with open(path, "rb") as f:
data = Pickle.load(f)
except (EOFError, IsADirectoryError):
logger.warning("Failed loading: {0}".format(path))
continue
os.remove(path)

if merged_data == {}:
merged_data = data
else:
for key, info in data.items():
if isinstance(info, list):
merged_data[key] += info
else:
for param_name, params in info.items():
try:
merged_data[key][param_name] += params
except KeyError as m:
logger.warning(
f"Keys [{key}][{param_name}] not found in \n {merged_data}"
)
raise KeyError(m)

with open(merged_path, "wb") as mp:
Pickle.dump(merged_data, mp)

if len(list(merged_data.keys())) > 0:
self.results[scale_shortener(float(sub_dir_name))] = merged_data

if len(list(self.results.keys())) == 0:
logger.warning("No data was found by ResultsHandler object! \n")
logger.warning(
"Tried root directory: \n {0} \n ".format(self.pickle_output_dir)
)
sys.exit()

def find_ns_scale(self):
"""Find the number of neutrinos corresponding to flux"""
try:
@@ -685,7 +761,7 @@ def find_disc_potential(self):
)

logger.info(
f"Scale: {scale}, TS_threshold: {disc_threshold}, n_trials: {len(ts_array)} => overfluctuations {frac=}"
f"Scale: {scale}, TS_threshold: {disc_threshold:.1f}, n_trials: {len(ts_array)} => overfluctuations {frac=:.4f}"
)

y[zval].append(frac)