@@ -911,15 +911,19 @@ def generate_init_population(
911911 return fit_population
912912
913913
914- def generation_step_callback (toolbox , ngen , population ):
914+ def generation_step_callback (
915+ run , gtp_scores , user_callback_per_generation , ngen , population
916+ ):
915917 """Called after each generation step cycle in train().
916918
917- :param toolbox: toolbox of the evolutionary algorithm.
919+ :param run: number of the current run
920+ :param gtp_scores: gtp_scores as of start of this run
921+ :param user_callback_per_generation: a user provided callback that is called
922+ after each training generation. It not None called like this:
923+ user_callback_per_generation(run, gtp_scores, ngen, population)
918924 :param ngen: the number of the current generation.
919925 :param population: the current population after generation ngen.
920926 """
921- run = toolbox .get_run ()
922- gtp_scores = toolbox .get_gtp_scores ()
923927 top_counter = print_population (run , ngen , population )
924928 top_gps = sorted (
925929 top_counter .keys (), key = attrgetter ("fitness" ), reverse = True
@@ -929,15 +933,18 @@ def generation_step_callback(toolbox, ngen, population):
929933 save_population (
930934 run , ngen , top_gps , generation_gtp_scores
931935 )
936+ if user_callback_per_generation :
937+ # user provided callback
938+ user_callback_per_generation (run , gtp_scores , ngen , population )
932939
933940
934941def find_graph_patterns (
935- sparql , run , gtp_scores ):
942+ sparql , run , gtp_scores ,
943+ user_callback_per_generation = None ,
944+ ):
936945 timeout = calibrate_query_timeout (sparql )
937946
938947 toolbox = deap .base .Toolbox ()
939- toolbox .register ("get_run" , lambda : run )
940- toolbox .register ("get_gtp_scores" , lambda : gtp_scores )
941948
942949 toolbox .register (
943950 "mate" , mate
@@ -952,7 +959,10 @@ def find_graph_patterns(
952959 )
953960 toolbox .register (
954961 "evaluate" , evaluate , sparql , timeout , gtp_scores )
955- toolbox .register ("generation_step_callback" , generation_step_callback )
962+ toolbox .register (
963+ "generation_step_callback" ,
964+ generation_step_callback , run , gtp_scores , user_callback_per_generation
965+ )
956966
957967
958968 population = generate_init_population (
@@ -985,11 +995,15 @@ def _find_graph_pattern_coverage_run(
985995 coverage_counts ,
986996 gtp_scores ,
987997 patterns ,
998+ user_callback_per_generation = None ,
999+ user_callback_per_run = None ,
9881000):
9891001 min_fitness = calc_min_fitness (gtp_scores , min_score )
9901002
9911003 ngen , res_pop , hall_of_fame , toolbox = find_graph_patterns (
992- sparql , run , gtp_scores )
1004+ sparql , run , gtp_scores ,
1005+ user_callback_per_generation = user_callback_per_generation ,
1006+ )
9931007
9941008 # TODO: coverage patterns should be chosen based on similarity
9951009 new_best_patterns = []
@@ -1085,6 +1099,11 @@ def _find_graph_pattern_coverage_run(
10851099 )
10861100 set_symlink (fp , config .SYMLINK_CURRENT_RES_RUN )
10871101
1102+ if user_callback_per_run :
1103+ user_callback_per_run (
1104+ run , gtp_scores , new_best_patterns , coverage_counts
1105+ )
1106+
10881107 return new_best_patterns , coverage_counts , gtp_scores
10891108
10901109
@@ -1096,6 +1115,8 @@ def find_graph_pattern_coverage(
10961115 max_runs = config .NRUNS ,
10971116 runs_no_improvement = config .NRUNS_NO_IMPROVEMENT ,
10981117 error_retries = config .ERROR_RETRIES ,
1118+ user_callback_per_generation = None ,
1119+ user_callback_per_run = None ,
10991120):
11001121 assert isinstance (ground_truth_pairs , tuple )
11011122
@@ -1135,6 +1156,8 @@ def find_graph_pattern_coverage(
11351156 coverage_counts ,
11361157 gtp_scores ,
11371158 patterns ,
1159+ user_callback_per_generation = user_callback_per_generation ,
1160+ user_callback_per_run = user_callback_per_run ,
11381161 )
11391162 new_best_patterns , coverage_counts , gtp_scores = res
11401163 patterns .update ({pat : run for pat , run in new_best_patterns })
0 commit comments