1313from runtype import dataclass
1414
1515from .utils import safezip , run_as_daemon
16+ from .thread_utils import ThreadedYielder
1617from .databases .database_types import IKey , NumericType , PrecisionType , StringType
1718from .table_segment import TableSegment
1819from .tracking import create_end_event_json , create_start_event_json , send_event_json , is_tracking_enabled
@@ -124,19 +125,22 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
124125 f"size: { table1 .approximate_size ()} "
125126 )
126127
128+ ti = ThreadedYielder (self .max_threadpool_size )
127129 # Bisect (split) the table into segments, and diff them recursively.
128- yield from self ._bisect_and_diff_tables ( table1 , table2 )
130+ ti . submit ( self ._bisect_and_diff_tables , ti , table1 , table2 )
129131
130132 # Now we check for the second min-max, to diff the portions we "missed".
131133 min_key2 , max_key2 = self ._parse_key_range_result (key_type , next (key_ranges ))
132134
133135 if min_key2 < min_key1 :
134136 pre_tables = [t .new (min_key = min_key2 , max_key = min_key1 ) for t in (table1 , table2 )]
135- yield from self ._bisect_and_diff_tables ( * pre_tables )
137+ ti . submit ( self ._bisect_and_diff_tables , ti , * pre_tables )
136138
137139 if max_key2 > max_key1 :
138140 post_tables = [t .new (min_key = max_key1 , max_key = max_key2 ) for t in (table1 , table2 )]
139- yield from self ._bisect_and_diff_tables (* post_tables )
141+ ti .submit (self ._bisect_and_diff_tables , ti , * post_tables )
142+
143+ yield from ti
140144
141145 except BaseException as e : # Catch KeyboardInterrupt too
142146 error = e
@@ -218,7 +222,7 @@ def _validate_and_adjust_columns(self, table1, table2):
218222 "If encoding/formatting differs between databases, it may result in false positives."
219223 )
220224
221- def _bisect_and_diff_tables (self , table1 , table2 , level = 0 , max_rows = None ):
225+ def _bisect_and_diff_tables (self , ti : ThreadedYielder , table1 , table2 , level = 0 , max_rows = None ):
222226 assert table1 .is_bounded and table2 .is_bounded
223227
224228 if max_rows is None :
@@ -242,8 +246,7 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
242246
243247 logger .info (". " * level + f"Diff found { len (diff )} different rows." )
244248 self .stats ["rows_downloaded" ] = self .stats .get ("rows_downloaded" , 0 ) + max (len (rows1 ), len (rows2 ))
245- yield from diff
246- return
249+ return diff
247250
248251 # Choose evenly spaced checkpoints (according to min_key and max_key)
249252 checkpoints = table1 .choose_checkpoints (self .bisection_factor - 1 )
@@ -253,15 +256,10 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
253256 segmented2 = table2 .segment_by_checkpoints (checkpoints )
254257
255258 # Recursively compare each pair of corresponding segments between table1 and table2
256- diff_iters = [
257- self ._diff_tables (t1 , t2 , level + 1 , i + 1 , len (segmented1 ))
258- for i , (t1 , t2 ) in enumerate (safezip (segmented1 , segmented2 ))
259- ]
260-
261- for res in self ._thread_map (list , diff_iters ):
262- yield from res
259+ for i , (t1 , t2 ) in enumerate (safezip (segmented1 , segmented2 )):
260+ ti .submit (self ._diff_tables , ti , t1 , t2 , level + 1 , i + 1 , len (segmented1 ), priority = level )
263261
264- def _diff_tables (self , table1 , table2 , level = 0 , segment_index = None , segment_count = None ):
262+ def _diff_tables (self , ti : ThreadedYielder , table1 , table2 , level = 0 , segment_index = None , segment_count = None ):
265263 logger .info (
266264 ". " * level + f"Diffing segment { segment_index } /{ segment_count } , "
267265 f"key-range: { table1 .min_key } ..{ table2 .max_key } , "
@@ -275,8 +273,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
275273 if BENCHMARK :
276274 max_rows_from_keys = max (table1 .max_key - table1 .min_key , table2 .max_key - table2 .min_key )
277275 if max_rows_from_keys < self .bisection_threshold :
278- yield from self ._bisect_and_diff_tables (table1 , table2 , level = level , max_rows = max_rows_from_keys )
279- return
276+ return self ._bisect_and_diff_tables (ti , table1 , table2 , level = level , max_rows = max_rows_from_keys )
280277
281278 (count1 , checksum1 ), (count2 , checksum2 ) = self ._threaded_call ("count_and_checksum" , [table1 , table2 ])
282279
@@ -293,7 +290,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
293290 self .stats ["table2_count" ] = self .stats .get ("table2_count" , 0 ) + count2
294291
295292 if checksum1 != checksum2 :
296- yield from self ._bisect_and_diff_tables (table1 , table2 , level = level , max_rows = max (count1 , count2 ))
293+ return self ._bisect_and_diff_tables (ti , table1 , table2 , level = level , max_rows = max (count1 , count2 ))
297294
298295 def _thread_map (self , func , iterable ):
299296 if not self .threaded :
0 commit comments