Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit af890da

Browse files
committed
Re-wrote threading to use a thread-pool + priority queue.
- KeyboardInterrupt is now handled correctly.
- The resulting iterator is now better behaved (--limit works a lot better).
1 parent b54e925 commit af890da

File tree

3 files changed

+94
-18
lines changed

3 files changed

+94
-18
lines changed

data_diff/__main__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ def _main(
197197
bisection_threshold=bisection_threshold,
198198
threaded=threaded,
199199
max_threadpool_size=threads and threads * 2,
200-
debug=debug,
201200
)
202201

203202
if database1 is None or database2 is None:

data_diff/diff_tables.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from runtype import dataclass
1414

1515
from .utils import safezip, run_as_daemon
16+
from .thread_utils import ThreadedYielder
1617
from .databases.database_types import IKey, NumericType, PrecisionType, StringType
1718
from .table_segment import TableSegment
1819
from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
@@ -124,19 +125,22 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
124125
f"size: {table1.approximate_size()}"
125126
)
126127

128+
ti = ThreadedYielder(self.max_threadpool_size)
127129
# Bisect (split) the table into segments, and diff them recursively.
128-
yield from self._bisect_and_diff_tables(table1, table2)
130+
ti.submit(self._bisect_and_diff_tables, ti, table1, table2)
129131

130132
# Now we check for the second min-max, to diff the portions we "missed".
131133
min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges))
132134

133135
if min_key2 < min_key1:
134136
pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)]
135-
yield from self._bisect_and_diff_tables(*pre_tables)
137+
ti.submit(self._bisect_and_diff_tables, ti, *pre_tables)
136138

137139
if max_key2 > max_key1:
138140
post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)]
139-
yield from self._bisect_and_diff_tables(*post_tables)
141+
ti.submit(self._bisect_and_diff_tables, ti, *post_tables)
142+
143+
yield from ti
140144

141145
except BaseException as e: # Catch KeyboardInterrupt too
142146
error = e
@@ -218,7 +222,7 @@ def _validate_and_adjust_columns(self, table1, table2):
218222
"If encoding/formatting differs between databases, it may result in false positives."
219223
)
220224

221-
def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
225+
def _bisect_and_diff_tables(self, ti: ThreadedYielder, table1, table2, level=0, max_rows=None):
222226
assert table1.is_bounded and table2.is_bounded
223227

224228
if max_rows is None:
@@ -242,8 +246,7 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
242246

243247
logger.info(". " * level + f"Diff found {len(diff)} different rows.")
244248
self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2))
245-
yield from diff
246-
return
249+
return diff
247250

248251
# Choose evenly spaced checkpoints (according to min_key and max_key)
249252
checkpoints = table1.choose_checkpoints(self.bisection_factor - 1)
@@ -253,15 +256,10 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
253256
segmented2 = table2.segment_by_checkpoints(checkpoints)
254257

255258
# Recursively compare each pair of corresponding segments between table1 and table2
256-
diff_iters = [
257-
self._diff_tables(t1, t2, level + 1, i + 1, len(segmented1))
258-
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2))
259-
]
260-
261-
for res in self._thread_map(list, diff_iters):
262-
yield from res
259+
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)):
260+
ti.submit(self._diff_tables, ti, t1, t2, level + 1, i + 1, len(segmented1), priority=level)
263261

264-
def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_count=None):
262+
def _diff_tables(self, ti: ThreadedYielder, table1, table2, level=0, segment_index=None, segment_count=None):
265263
logger.info(
266264
". " * level + f"Diffing segment {segment_index}/{segment_count}, "
267265
f"key-range: {table1.min_key}..{table2.max_key}, "
@@ -275,8 +273,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
275273
if BENCHMARK:
276274
max_rows_from_keys = max(table1.max_key - table1.min_key, table2.max_key - table2.min_key)
277275
if max_rows_from_keys < self.bisection_threshold:
278-
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max_rows_from_keys)
279-
return
276+
return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows_from_keys)
280277

281278
(count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2])
282279

@@ -293,7 +290,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
293290
self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2
294291

295292
if checksum1 != checksum2:
296-
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2))
293+
return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2))
297294

298295
def _thread_map(self, func, iterable):
299296
if not self.threaded:

data_diff/thread_utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import itertools
2+
from concurrent.futures.thread import _WorkItem
3+
from queue import PriorityQueue
4+
from collections import deque
5+
from collections.abc import Iterable
6+
from concurrent.futures import ThreadPoolExecutor
7+
from time import sleep
8+
from typing import Callable, Iterator, Optional
9+
10+
11+
class AutoPriorityQueue(PriorityQueue):
    """Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs

    We also assign a unique id for each item, to avoid making comparisons on _WorkItem.
    As a side effect, items with the same priority are returned FIFO.

    Entries are stored as ``(-priority, sequence_number, item)`` tuples, so an
    item with a higher priority is retrieved before one with a lower priority.
    ``None`` is accepted as an item and is queued with priority 0.
    """

    # Class-level sequence generator: shared by every queue instance. Only the
    # uniqueness and monotonic growth of the numbers matter, not their values.
    _counter = itertools.count().__next__

    def put(self, item: Optional[_WorkItem], block=True, timeout=None):
        """Enqueue ``item`` using the priority found in ``item.kwargs``.

        The ``"priority"`` key is popped, so it is no longer present in the
        keyword arguments stored on the item. An item that carries no
        ``"priority"`` key is queued with priority 0 instead of raising
        ``KeyError``.
        """
        priority = item.kwargs.pop("priority", 0) if item is not None else 0
        # Negate the priority: PriorityQueue returns the smallest entry first.
        super().put((-priority, self._counter(), item), block, timeout)

    def get(self, block=True, timeout=None) -> Optional[_WorkItem]:
        """Return the next item, discarding the priority and sequence number."""
        _p, _c, work_item = super().get(block, timeout)
        return work_item
27+
28+
29+
class PriorityThreadPoolExecutor(ThreadPoolExecutor):
30+
"""Overrides ThreadPoolExecutor to use AutoPriorityQueue
31+
32+
XXX WARNING: Might break in future versions of Python
33+
"""
34+
35+
def __init__(self, *args):
36+
super().__init__(*args)
37+
38+
self._work_queue = AutoPriorityQueue()
39+
40+
41+
class ThreadedYielder(Iterable):
42+
"""Yields results from multiple threads into a single iterator, ordered by priority.
43+
44+
To add a source iterator, call ``submit()`` with a function that returns an iterator.
45+
Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first)
46+
"""
47+
48+
def __init__(self, max_workers: Optional[int] = None):
49+
self._pool = PriorityThreadPoolExecutor(max_workers)
50+
self._futures = deque()
51+
self._yield = deque()
52+
self._exception = None
53+
54+
def _worker(self, fn, *args, **kwargs):
55+
try:
56+
res = fn(*args, **kwargs)
57+
if res is not None:
58+
self._yield += res
59+
except Exception as e:
60+
self._exception = e
61+
62+
def submit(self, fn: Callable, *args, priority: int = 0, **kwargs):
63+
self._futures.append(self._pool.submit(self._worker, fn, *args, priority=priority, **kwargs))
64+
65+
def __iter__(self) -> Iterator:
66+
while True:
67+
if self._exception:
68+
raise self._exception
69+
70+
while self._yield:
71+
yield self._yield.popleft()
72+
73+
if not self._futures:
74+
# No more tasks
75+
return
76+
77+
if self._futures[0].done():
78+
self._futures.popleft()
79+
else:
80+
sleep(0.001)

0 commit comments

Comments
 (0)