@@ -26,6 +26,7 @@ def parse_args():
2626 parser .add_argument ("-options" , "--options" , default = None , nargs = "+" )
2727 parser .add_argument ("-threads" , default = 2 , type = int )
2828 parser .add_argument ("-resume" , default = False , action = "store_true" )
29+ parser .add_argument ("-metric" , default = "avg_inc_acc" , choices = ["avg_inc_acc" , "last_acc" ])
2930
3031 return parser .parse_args ()
3132
@@ -45,6 +46,7 @@ def train_func(config, reporter):
4546
4647 total_avg_inc_acc = statistics .mean (all_acc )
4748 reporter (avg_inc_acc = total_avg_inc_acc )
49+ #reporter(last_acc=last_acc)
4850 return total_avg_inc_acc
4951
5052
@@ -54,24 +56,27 @@ def _get_abs_path(path):
5456 return os .path .join (os .path .dirname (os .path .realpath (__file__ )), path )
5557
5658
57- def analyse_ray_dump (ray_directory , topn ):
59+ def analyse_ray_dump (ray_directory , topn , metric = "avg_inc_acc" ):
60+ if metric not in ("avg_inc_acc" , "last_acc" ):
61+ raise NotImplementedError ("Unknown metric {}." .format (metric ))
62+
5863 ea = Analysis (ray_directory )
5964 trials_dataframe = ea .dataframe ()
60- trials_dataframe = trials_dataframe .sort_values (by = "avg_inc_acc" , ascending = False )
65+ trials_dataframe = trials_dataframe .sort_values (by = metric , ascending = False )
6166
6267 mapping_col_to_index = {}
6368 result_index = - 1
6469 for index , col in enumerate (trials_dataframe .columns ):
6570 if col .startswith ("config:" ):
6671 mapping_col_to_index [col [7 :]] = index
67- elif col == "avg_inc_acc" :
72+ elif col == metric :
6873 result_index = index
6974
7075 print ("Ray config: {}" .format (ray_directory ))
7176 print ("Best Config:" )
7277 print (
73- "avg_inc_acc : {} with {}." .format (
74- trials_dataframe .iloc [0 ][result_index ],
78+ "{} : {} with {}." .format (
79+ metric , trials_dataframe .iloc [0 ][result_index ],
7580 _get_line_results (trials_dataframe , 0 , mapping_col_to_index )
7681 )
7782 )
@@ -119,6 +124,9 @@ def get_tune_config(tune_options, options_files):
119124 with open (tune_options ) as f :
120125 options = yaml .load (f , Loader = yaml .FullLoader )
121126
127+ if "epochs" in options and options ["epochs" ] == 1 :
128+ raise ValueError ("Using only 1 epoch, must be a mistake." )
129+
122130 config = {}
123131 for k , v in options .items ():
124132 if not k .startswith ("var:" ):
@@ -141,6 +149,12 @@ def main():
141149 if args .tune is not None :
142150 config = get_tune_config (args .tune , args .options )
143151 config ["threads" ] = args .threads
152+
153+ try :
154+ os .system ("echo '\ek{}_gridsearch\e\\ '" .format (args .tune ))
155+ except :
156+ pass
157+
144158 ray .init ()
145159 tune .run (
146160 train_func ,
@@ -158,10 +172,12 @@ def main():
158172 args .ray_directory = os .path .join (args .ray_directory , args .tune .rstrip ("/" ).split ("/" )[- 1 ])
159173
160174 if args .tune is not None :
161- print ("\n \n " , args .tune , "\n \n " )
175+ print ("\n \n " , args .tune , args . options , "\n \n " )
162176
163177 if args .ray_directory is not None :
164- best_config = analyse_ray_dump (_get_abs_path (args .ray_directory ), args .topn )
178+ best_config = analyse_ray_dump (
179+ _get_abs_path (args .ray_directory ), args .topn , metric = args .metric
180+ )
165181
166182 if args .output_options :
167183 with open (args .output_options , "w+" ) as f :
0 commit comments