@@ -37,6 +37,113 @@ class OverfluctuationError(Exception):
     pass


+class PickleCache:
+    def __init__(self, pickle_path: Path, background_only: bool = False):
+        self.path = pickle_path
+        # self.path.mkdir(parents=True, exist_ok=True)
+        self.merged_path = self.path / "merged"
+        self.merged_path.mkdir(parents=True, exist_ok=True)
+        self.background_only = background_only
+
+    def clean_merged_data(self):
+        """Remove all merged files from the cache."""
+        # Remove all files inside merged_path using pathlib.
+        if self.merged_path.exists():
+            for file in self.merged_path.iterdir():
+                if file.is_file():
+                    file.unlink()
+            logger.debug(f"Removed all files from {self.merged_path}")
+
+    def get_subdirs(self):
+        return [
+            x for x in os.listdir(self.path) if x[0] != "." and x != "merged"
+        ]
+
+    def merge_datadict(
+        self, merged: dict[str, list | dict], pending_data: dict[str, list | dict]
+    ):
+        """Merge the content of pending_data into merged."""
+        for key, info in pending_data.items():
+            if isinstance(info, list):
+                # Append the list to the existing one.
+                merged[key] += info
+            elif isinstance(info, dict):
+                for param_name, params in info.items():
+                    try:
+                        merged[key][param_name] += params
+                    except KeyError as m:
+                        logger.warning(
+                            f"Keys [{key}][{param_name}] not found in\n{merged}"
+                        )
+                        raise KeyError(m)
+            else:
+                raise TypeError(
+                    f"Unexpected type for key {key}: {type(info)}. Expected list or dict."
+                )
+
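+    # A sketch of merge_datadict semantics (the key names below are
+    # hypothetical, not prescribed by this class):
+    #
+    #     merged = {"TS": [1.0], "Parameters": {"n_s": [2.0]}}
+    #     pending = {"TS": [3.5], "Parameters": {"n_s": [4.0]}}
+    #     cache.merge_datadict(merged, pending)
+    #     assert merged == {"TS": [1.0, 3.5], "Parameters": {"n_s": [2.0, 4.0]}}
+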
+    def merge_and_load_subdir(self, subdir_name):
+        """Merge and load data from a single subdirectory."""
+        subdir = os.path.join(self.path, subdir_name)
+
+        files = os.listdir(subdir)
+
+        # Map one directory to one pickle.
+        merged_file_path = os.path.join(self.merged_path, subdir_name + ".pkl")
+        # Load previously merged data, if it exists.
+        if os.path.isfile(merged_file_path):
+            logger.debug(f"loading merged data from {merged_file_path}")
+            with open(merged_file_path, "rb") as mp:
+                merged_data = Pickle.load(mp)
+        else:
+            merged_data = {}
+
+        for filename in files:
+            pending_file = os.path.join(subdir, filename)
+
+            try:
+                with open(pending_file, "rb") as f:
+                    data = Pickle.load(f)
+            except (EOFError, IsADirectoryError):
+                logger.warning("Failed loading: {0}".format(pending_file))
+                continue
+            # Remove the file immediately. This can lead to undesired results:
+            # if the program crashes or is terminated here, files are removed
+            # before the merged data is written. Delaying the removal would
+            # create the opposite problem: a file that is merged but not
+            # removed could be merged a second time at the next run.
+            os.remove(pending_file)
+
+            if merged_data == {}:
+                merged_data = data
+            else:
+                self.merge_datadict(merged_data, data)
+
+        # Save merged data.
+        with open(merged_file_path, "wb") as mp:
+            Pickle.dump(merged_data, mp)
+
+        return merged_data
+
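+    # On-disk layout handled by merge_and_load_subdir (names illustrative):
+    #
+    #     <path>/<scale_subdir>/<trial>.pkl  -> per-trial pickles, deleted once merged
+    #     <path>/merged/<scale_subdir>.pkl   -> accumulated merged trials
+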
+    def merge_and_load(self, output_dict: dict):
+        # Loop over all injection scale subdirectories.
+        scales_subdirs = self.get_subdirs()
+
+        background_label = scale_shortener(0.0)
+
+        for subdir_name in scales_subdirs:
+            scale_label = scale_shortener(float(subdir_name))
+
+            if self.background_only and scale_label != background_label:
+                # Skip non-background trials.
+                continue
+
+            pending_data = self.merge_and_load_subdir(subdir_name)
+
+            if pending_data:
+                if scale_label == background_label and background_label in output_dict:
+                    self.merge_datadict(output_dict[background_label], pending_data)
+                else:
+                    output_dict[scale_label] = pending_data
+
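+# Minimal usage sketch of PickleCache (illustrative only; the path and the
+# results dict below are placeholders, not part of this change):
+#
+#     cache = PickleCache(Path("/base/pickle/dir"))
+#     results: dict = {}
+#     cache.merge_and_load(output_dict=results)
+#     # results now maps each injection-scale label to its merged trial data.
+
+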
 class ResultsHandler(object):
     def __init__(
         self,
@@ -45,10 +152,17 @@ def __init__(
         do_disc=True,
         bias_error="std",
         sigma_thresholds=[3.0, 5.0],
+        background_from=None,
     ):
         self.sources = load_catalogue(rh_dict["catalogue"])

         self.name = rh_dict["name"]
+
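+        # If requested, reuse the background trials of a different analysis;
+        # by default, the analysis provides its own background.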
+        if background_from is not None:
+            self.background_from = background_from
+        else:
+            self.background_from = rh_dict["name"]
+
         self.mh_name = rh_dict["mh_name"]

         self._inj_dict = rh_dict["inj_dict"]
@@ -59,11 +173,14 @@ def __init__(
         self.maxfev = rh_dict.get("maxfev", 800)

         self.results = dict()
+
         self.pickle_output_dir = name_pickle_output_dir(self.name)
+        self.pickle_output_dir_bg = name_pickle_output_dir(self.background_from)

-        self.plot_path = Path(plot_output_dir(self.name))
+        self.pickle_cache = PickleCache(Path(self.pickle_output_dir))
+        self.pickle_cache_bg = PickleCache(
+            Path(self.pickle_output_dir_bg), background_only=True
+        )

-        self.merged_dir = os.path.join(self.pickle_output_dir, "merged")
+        self.plot_path = Path(plot_output_dir(self.name))

         self.allow_extrapolation = rh_dict.get("allow_extrapolated_sensitivity", True)

@@ -101,7 +218,8 @@ def __init__(
             "extrapolated": False,
         }

-        # Load injection ladder values
+        # Load injection ladder values.
+        # Builds a dictionary mapping each injection scale to its injected values.
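+        # (e.g. self.inj["0.0"], self.inj["2.0"], ...; the exact scale labels
+        # depend on the injection ladder of this analysis)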
         try:
             self.inj = self.load_injection_values()
         except FileNotFoundError as err:
@@ -112,17 +230,17 @@ def __init__(
             self.valid = False
             return

-        # Load and merge the trial results
         try:
-            self.merge_pickle_data()
+            self.pickle_cache.merge_and_load(output_dict=self.results)
+            # Load the background trials; they are merged into any existing
+            # background entry.
+            self.pickle_cache_bg.merge_and_load(output_dict=self.results)
         except FileNotFoundError:
             logger.warning(f"No files found at {self.pickle_output_dir}")

         # auxiliary parameters
-        # self.sorted_scales = sorted(self.results.keys())
         self.scale_values = sorted(
             [float(j) for j in self.results.keys()]
-        )  # replaces self.scales_float
+        )
         self.scale_labels = [scale_shortener(i) for i in self.scale_values]

         logger.info(f"Injection scales: {self.scale_values}")
@@ -131,17 +249,14 @@ def __init__(
         # Determine the injection scales
         try:
             self.find_ns_scale()
+            self.plot_bias()
         except ValueError as e:
             logger.warning(f"RuntimeError for ns scale factor:\n{e}")
         except IndexError as e:
             logger.warning(
                 f"IndexError for ns scale factor. Only background trials?\n{e}"
             )

-        # Create fit bias plots
-        # this expects flux_to_ns to be set
-        self.plot_bias()
-
         if do_sens:
             try:
                 self.find_sensitivity()
@@ -278,12 +393,9 @@ def nu_astronomy(self, flux, e_pdf_dict):
         return calculate_astronomy(flux, e_pdf_dict)

     def clean_merged_data(self):
-        """Function to clear cache of all data"""
-        try:
-            for f in os.listdir(self.merged_dir):
-                os.remove(self.merged_dir + f)
-        except OSError:
-            pass
+        """Clean the merged data of the main analysis only; the background
+        cache is deliberately left untouched."""
+        self.pickle_cache.clean_merged_data()
+

     def load_injection_values(self):
         """Function to load the values used in injection, so that a
@@ -311,25 +423,33 @@ def load_injection_values(self):

         return inj_values

-    def merge_pickle_data(self):
+
+    def merge_and_load_pickle_data(self):
+        # NOTE:
+        # self.pickle_output_path
+        # self.merged_dir = self.pickle_output_path / "merged"
+
+        # Loop over all subdirectories, one per injection scale, each
+        # containing one pickle per trial.
         all_sub_dirs = [
             x
-            for x in os.listdir(self.pickle_output_dir)
+            for x in os.listdir(self.path)
             if x[0] != "." and x != "merged"
         ]
-
+        # Create a "merged" directory that will contain a single pickle with
+        # many trials per injection scale.
         try:
             os.makedirs(self.merged_dir)
         except OSError:
             pass

         for sub_dir_name in all_sub_dirs:
-            sub_dir = os.path.join(self.pickle_output_dir, sub_dir_name)
+            sub_dir = os.path.join(self.path, sub_dir_name)

             files = os.listdir(sub_dir)

+            # Map one directory to one pickle.
             merged_path = os.path.join(self.merged_dir, sub_dir_name + ".pkl")
-
+            # Load previously merged data, if it exists.
             if os.path.isfile(merged_path):
                 logger.debug(f"loading merged data from {merged_path}")
                 with open(merged_path, "rb") as mp:
@@ -338,15 +458,16 @@ def merge_pickle_data(self):
                 merged_data = {}

             for filename in files:
-                path = os.path.join(sub_dir, filename)
+                pending_file = os.path.join(sub_dir, filename)

                 try:
-                    with open(path, "rb") as f:
+                    with open(pending_file, "rb") as f:
                         data = Pickle.load(f)
                 except (EOFError, IsADirectoryError):
-                    logger.warning("Failed loading: {0}".format(path))
+                    logger.warning("Failed loading: {0}".format(pending_file))
                     continue
-                os.remove(path)
+                # This can be "dangerous": if the program crashes or is
+                # terminated here, files are removed before the merged data
+                # is written.
+                os.remove(pending_file)

                 if merged_data == {}:
                     merged_data = data
@@ -364,13 +485,15 @@ def merge_pickle_data(self):
                                 )
                                 raise KeyError(m)

+            # Save merged data.
             with open(merged_path, "wb") as mp:
                 Pickle.dump(merged_data, mp)

-            if len(list(merged_data.keys())) > 0:
+            # Load the merged data into results.
+            if merged_data:
                 self.results[scale_shortener(float(sub_dir_name))] = merged_data

-        if len(list(self.results.keys())) == 0:
+        if not self.results:
             logger.warning("No data was found by ResultsHandler object!\n")
             logger.warning(
                 "Tried root directory:\n{0}\n".format(self.pickle_output_dir)