Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 67171cb

Browse files
authored
Merge pull request #243 from datafold/next_master
Merge of #238 and #235
2 parents cb24ac9 + ba5eabf commit 67171cb

File tree

9 files changed

+273
-91
lines changed

9 files changed

+273
-91
lines changed

data_diff/__main__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@ def _main(
197197
bisection_threshold=bisection_threshold,
198198
threaded=threaded,
199199
max_threadpool_size=threads and threads * 2,
200-
debug=debug,
201200
)
202201

203202
if database1 is None or database2 is None:

data_diff/databases/base.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
Float,
1717
ColType_UUID,
1818
Native_UUID,
19-
String_Alphanum,
2019
String_UUID,
20+
String_Alphanum,
21+
String_FixedAlphanum,
22+
String_VaryingAlphanum,
2123
TemporalType,
2224
UnknownColType,
2325
Text,
@@ -79,6 +81,7 @@ class Database(AbstractDatabase):
7981

8082
TYPE_CLASSES: Dict[str, type] = {}
8183
default_schema: str = None
84+
SUPPORTS_ALPHANUMS = True
8285

8386
@property
8487
def name(self):
@@ -229,23 +232,22 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
229232
col_dict[col_name] = String_UUID()
230233
continue
231234

232-
alphanum_samples = [s for s in samples if s and String_Alphanum.test_value(s)]
233-
if alphanum_samples:
234-
if len(alphanum_samples) != len(samples):
235-
logger.warning(
236-
f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}, disabling Alphanum support."
237-
)
238-
else:
239-
assert col_name in col_dict
240-
lens = set(map(len, alphanum_samples))
241-
if len(lens) > 1:
235+
if self.SUPPORTS_ALPHANUMS: # Anything but MySQL (so far)
236+
alphanum_samples = [s for s in samples if String_Alphanum.test_value(s)]
237+
if alphanum_samples:
238+
if len(alphanum_samples) != len(samples):
242239
logger.warning(
243-
f"Mixed Alphanum lengths detected in column {'.'.join(table_path)}.{col_name}, disabling Alphanum support."
240+
f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}. It cannot be used as a key."
244241
)
245242
else:
246-
(length,) = lens
247-
col_dict[col_name] = String_Alphanum(length=length)
248-
continue
243+
assert col_name in col_dict
244+
lens = set(map(len, alphanum_samples))
245+
if len(lens) > 1:
246+
col_dict[col_name] = String_VaryingAlphanum()
247+
else:
248+
(length,) = lens
249+
col_dict[col_name] = String_FixedAlphanum(length=length)
250+
continue
249251

250252
# @lru_cache()
251253
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:

data_diff/databases/database_types.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,7 @@ class String_UUID(StringType, ColType_UUID):
9292
pass
9393

9494

95-
@dataclass
9695
class String_Alphanum(StringType, ColType_Alphanum):
97-
length: int
98-
9996
@staticmethod
10097
def test_value(value: str) -> bool:
10198
try:
@@ -104,6 +101,18 @@ def test_value(value: str) -> bool:
104101
except ValueError:
105102
return False
106103

104+
def make_value(self, value):
105+
return self.python_type(value)
106+
107+
108+
class String_VaryingAlphanum(String_Alphanum):
109+
pass
110+
111+
112+
@dataclass
113+
class String_FixedAlphanum(String_Alphanum):
114+
length: int
115+
107116
def make_value(self, value):
108117
if len(value) != self.length:
109118
raise ValueError(f"Expected alphanumeric value of length {self.length}, but got '{value}'.")

data_diff/databases/mysql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class MySQL(ThreadedDatabase):
2828
"binary": Text,
2929
}
3030
ROUNDS_ON_PREC_LOSS = True
31+
SUPPORTS_ALPHANUMS = False
3132

3233
def __init__(self, *, thread_count, **kw):
3334
self._args = kw

data_diff/diff_tables.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from runtype import dataclass
1414

1515
from .utils import safezip, run_as_daemon
16+
from .thread_utils import ThreadedYielder
1617
from .databases.database_types import IKey, NumericType, PrecisionType, StringType
1718
from .table_segment import TableSegment
1819
from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
@@ -121,22 +122,25 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
121122
logger.info(
122123
f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. "
123124
f"key-range: {table1.min_key}..{table2.max_key}, "
124-
f"size: {table1.approximate_size()}"
125+
f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}"
125126
)
126127

128+
ti = ThreadedYielder(self.max_threadpool_size)
127129
# Bisect (split) the table into segments, and diff them recursively.
128-
yield from self._bisect_and_diff_tables(table1, table2)
130+
ti.submit(self._bisect_and_diff_tables, ti, table1, table2)
129131

130132
# Now we check for the second min-max, to diff the portions we "missed".
131133
min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges))
132134

133135
if min_key2 < min_key1:
134136
pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)]
135-
yield from self._bisect_and_diff_tables(*pre_tables)
137+
ti.submit(self._bisect_and_diff_tables, ti, *pre_tables)
136138

137139
if max_key2 > max_key1:
138140
post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)]
139-
yield from self._bisect_and_diff_tables(*post_tables)
141+
ti.submit(self._bisect_and_diff_tables, ti, *post_tables)
142+
143+
yield from ti
140144

141145
except BaseException as e: # Catch KeyboardInterrupt too
142146
error = e
@@ -218,12 +222,12 @@ def _validate_and_adjust_columns(self, table1, table2):
218222
"If encoding/formatting differs between databases, it may result in false positives."
219223
)
220224

221-
def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
225+
def _bisect_and_diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, level=0, max_rows=None):
222226
assert table1.is_bounded and table2.is_bounded
223227

224228
if max_rows is None:
225229
# We can be sure that row_count <= max_rows
226-
max_rows = table1.max_key - table1.min_key
230+
max_rows = max(table1.approximate_size(), table2.approximate_size())
227231

228232
# If count is below the threshold, just download and compare the columns locally
229233
# This saves time, as bisection speed is limited by ping and query performance.
@@ -242,8 +246,7 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
242246

243247
logger.info(". " * level + f"Diff found {len(diff)} different rows.")
244248
self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2))
245-
yield from diff
246-
return
249+
return diff
247250

248251
# Choose evenly spaced checkpoints (according to min_key and max_key)
249252
checkpoints = table1.choose_checkpoints(self.bisection_factor - 1)
@@ -253,38 +256,31 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
253256
segmented2 = table2.segment_by_checkpoints(checkpoints)
254257

255258
# Recursively compare each pair of corresponding segments between table1 and table2
256-
diff_iters = [
257-
self._diff_tables(t1, t2, level + 1, i + 1, len(segmented1))
258-
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2))
259-
]
259+
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)):
260+
ti.submit(self._diff_tables, ti, t1, t2, max_rows, level + 1, i + 1, len(segmented1), priority=level)
260261

261-
for res in self._thread_map(list, diff_iters):
262-
yield from res
263-
264-
def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_count=None):
262+
def _diff_tables(self, ti: ThreadedYielder, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None):
265263
logger.info(
266264
". " * level + f"Diffing segment {segment_index}/{segment_count}, "
267265
f"key-range: {table1.min_key}..{table2.max_key}, "
268-
f"size: {table2.max_key-table1.min_key}"
266+
f"size <= {max_rows}"
269267
)
270268

271269
# When benchmarking, we want the ability to skip checksumming. This
272270
# allows us to download all rows for comparison in performance. By
273271
# default, data-diff will checksum the section first (when it's below
274272
# the threshold) and _then_ download it.
275273
if BENCHMARK:
276-
max_rows_from_keys = max(table1.max_key - table1.min_key, table2.max_key - table2.min_key)
277-
if max_rows_from_keys < self.bisection_threshold:
278-
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max_rows_from_keys)
279-
return
274+
if max_rows < self.bisection_threshold:
275+
return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max_rows)
280276

281277
(count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2])
282278

283279
if count1 == 0 and count2 == 0:
284-
logger.warning(
285-
"Uneven distribution of keys detected. (big gaps in the key column). "
286-
"For better performance, we recommend to increase the bisection-threshold."
287-
)
280+
# logger.warning(
281+
# f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). "
282+
# "For better performance, we recommend to increase the bisection-threshold."
283+
# )
288284
assert checksum1 is None and checksum2 is None
289285
return
290286

@@ -293,7 +289,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
293289
self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2
294290

295291
if checksum1 != checksum2:
296-
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2))
292+
return self._bisect_and_diff_tables(ti, table1, table2, level=level, max_rows=max(count1, count2))
297293

298294
def _thread_map(self, func, iterable):
299295
if not self.threaded:

data_diff/table_segment.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from runtype import dataclass
66

7-
from .utils import ArithString, split_space
7+
from .utils import ArithString, split_space, ArithAlphanumeric
88

99
from .databases.base import Database
1010
from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema, create_schema
@@ -149,8 +149,9 @@ def choose_checkpoints(self, count: int) -> List[DbKey]:
149149
assert self.is_bounded
150150
if isinstance(self.min_key, ArithString):
151151
assert type(self.min_key) is type(self.max_key)
152-
checkpoints = split_space(self.min_key.int, self.max_key.int, count)
153-
return [self.min_key.new(int=i) for i in checkpoints]
152+
checkpoints = self.min_key.range(self.max_key, count)
153+
assert all(self.min_key <= x <= self.max_key for x in checkpoints)
154+
return checkpoints
154155

155156
return split_space(self.min_key, self.max_key, count)
156157

data_diff/thread_utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import itertools
2+
from concurrent.futures.thread import _WorkItem
3+
from queue import PriorityQueue
4+
from collections import deque
5+
from collections.abc import Iterable
6+
from concurrent.futures import ThreadPoolExecutor
7+
from time import sleep
8+
from typing import Callable, Iterator, Optional
9+
10+
11+
class AutoPriorityQueue(PriorityQueue):
    """A ``PriorityQueue`` that derives each item's priority from ``_WorkItem.kwargs``.

    ``put()`` pops the ``priority`` keyword out of the work item's kwargs (so the
    wrapped callable never receives it) and enqueues higher-priority items first.
    A monotonically increasing counter breaks ties, which both avoids comparing
    ``_WorkItem`` objects directly (they are not orderable) and makes items of
    equal priority come out FIFO.
    """

    # Class-level counter shared by all instances; used only for tie-breaking.
    _counter = itertools.count().__next__

    def put(self, item: Optional[_WorkItem], block=True, timeout=None):
        # ``None`` is the executor's shutdown sentinel; give it neutral priority.
        # Default to 0 when no explicit priority was supplied -- the original
        # ``pop("priority")`` raised KeyError on a plain ``executor.submit(fn)``.
        priority = item.kwargs.pop("priority", 0) if item is not None else 0
        # PriorityQueue pops the smallest tuple first, so negate the priority
        # to make higher values run earlier.
        super().put((-priority, self._counter(), item), block, timeout)

    def get(self, block=True, timeout=None) -> Optional[_WorkItem]:
        _priority, _count, work_item = super().get(block, timeout)
        return work_item
27+
28+
29+
class PriorityThreadPoolExecutor(ThreadPoolExecutor):
    """A ``ThreadPoolExecutor`` whose work queue is an ``AutoPriorityQueue``.

    Tasks submitted with a ``priority`` keyword run higher-priority first
    instead of strictly FIFO.

    XXX WARNING: Relies on the private ``_work_queue`` attribute and therefore
    might break in future versions of Python.
    """

    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as well (max_workers, thread_name_prefix,
        # initializer, initargs) -- the original only forwarded positionals,
        # so keyword-style construction raised a TypeError.
        super().__init__(*args, **kwargs)

        # Swap the default FIFO queue for a priority queue. Safe here because
        # worker threads are only spawned lazily, on the first submit().
        self._work_queue = AutoPriorityQueue()
39+
40+
41+
class ThreadedYielder(Iterable):
    """Yields results from multiple threads into a single iterator, ordered by priority.

    To add a source iterator, call ``submit()`` with a function that returns an iterator.
    Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first)
    """

    def __init__(self, max_workers: Optional[int] = None):
        # Executor that runs higher-priority submissions first.
        self._pool = PriorityThreadPoolExecutor(max_workers)
        # Futures of every submitted task, in submission order.
        self._futures = deque()
        # Results produced by workers, drained by __iter__. deque append/popleft
        # are used cross-thread here without extra locking -- NOTE(review):
        # presumably relying on their atomicity in CPython; confirm if porting.
        self._yield = deque()
        # First/latest exception raised by a worker; re-raised in __iter__.
        self._exception = None

    def _worker(self, fn, *args, **kwargs):
        """Run ``fn`` on a pool thread and buffer its results.

        If ``fn`` returns a non-None iterable, it is consumed here (on the
        worker thread) into the shared ``_yield`` deque. Any exception is
        captured so the consuming thread can re-raise it from ``__iter__``.
        """
        try:
            res = fn(*args, **kwargs)
            if res is not None:
                self._yield += res
        except Exception as e:
            # If several workers fail, a later failure overwrites an earlier one.
            self._exception = e

    def submit(self, fn: Callable, *args, priority: int = 0, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` on the pool.

        ``priority`` is consumed by the pool's queue (higher runs first) and is
        not passed on to ``fn``. Workers may themselves call ``submit()`` to
        enqueue follow-up tasks (recursive fan-out).
        """
        self._futures.append(self._pool.submit(self._worker, fn, *args, priority=priority, **kwargs))

    def __iter__(self) -> Iterator:
        """Yield buffered results until all submitted tasks have completed.

        Polling loop: re-raise any captured worker exception, drain whatever
        the workers have produced so far, then either retire the oldest
        finished future or sleep briefly and try again.
        """
        while True:
            if self._exception:
                raise self._exception

            while self._yield:
                yield self._yield.popleft()

            if not self._futures:
                # No more tasks
                return

            if self._futures[0].done():
                # Oldest task finished; anything it produced is already in
                # ``_yield`` and gets drained on the next loop iteration.
                self._futures.popleft()
            else:
                # Oldest task still running -- brief sleep avoids a hot spin.
                sleep(0.001)

0 commit comments

Comments
 (0)