Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit daf2d94

Browse files
committed
Adjustments for PR: Use DiffTestCase, remove unneeded tests, and style fixes
1 parent a9be91b commit daf2d94

File tree

4 files changed

+35
-109
lines changed

4 files changed

+35
-109
lines changed

data_diff/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -317,7 +317,6 @@ def _main(
317317
logging.error(e)
318318
return
319319

320-
321320
now: datetime = db1.query(current_timestamp(), datetime)
322321
now = now.replace(tzinfo=None)
323322
try:
@@ -405,6 +404,7 @@ def _main(
405404
diff_iter = differ.diff_tables(*segments)
406405

407406
if limit:
407+
assert not stats
408408
diff_iter = islice(diff_iter, int(limit))
409409

410410
if stats:

data_diff/diff_tables.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
from concurrent.futures import ThreadPoolExecutor, as_completed
1212

1313
from runtype import dataclass
14-
from dataclasses import field
1514

1615
from data_diff.info_tree import InfoTree, SegmentInfo
1716

@@ -93,7 +92,7 @@ class DiffResultWrapper:
9392
diff: iter # DiffResult
9493
info_tree: InfoTree
9594
stats: dict
96-
result_list: list = field(default_factory=list)
95+
result_list: list = []
9796

9897
def __iter__(self):
9998
yield from self.result_list
@@ -102,8 +101,8 @@ def __iter__(self):
102101
yield i
103102

104103
def _get_stats(self) -> DiffStats:
104+
list(self) # Consume the iterator into result_list, if we haven't already
105105

106-
list(self)
107106
diff_by_key = {}
108107
for sign, values in self.result_list:
109108
k = values[: len(self.info_tree.info.tables[0].key_columns)]
@@ -179,7 +178,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf
179178
"""
180179
if info_tree is None:
181180
info_tree = InfoTree(SegmentInfo([table1, table2]))
182-
return DiffResultWrapper(self._diff_tables_wrapper(table1, table2, info_tree), info_tree, self.stats, [])
181+
return DiffResultWrapper(self._diff_tables_wrapper(table1, table2, info_tree), info_tree, self.stats)
183182

184183
def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult:
185184
if is_tracking_enabled():

tests/test_api.py

Lines changed: 15 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -1,35 +1,22 @@
1-
import unittest
2-
import io
3-
import unittest.mock
41
import arrow
52
from datetime import datetime
63

74
from data_diff import diff_tables, connect_to_table
85
from data_diff.databases import MySQL
96
from data_diff.sqeleton.queries import table, commit
107

11-
from .common import TEST_MYSQL_CONN_STRING, get_conn, random_table_suffix
8+
from .common import TEST_MYSQL_CONN_STRING, get_conn, random_table_suffix, DiffTestCase
129

1310

14-
def _commit(conn):
15-
conn.query(commit)
11+
class TestApi(DiffTestCase):
12+
src_schema = {"id": int, "datetime": datetime, "text_comment": str}
13+
db_cls = MySQL
1614

17-
18-
class TestApi(unittest.TestCase):
1915
def setUp(self) -> None:
20-
self.conn = get_conn(MySQL)
21-
suffix = random_table_suffix()
22-
self.table_src_name = f"test_api{suffix}"
23-
self.table_dst_name = f"test_api_2{suffix}"
24-
25-
self.table_src = table(self.table_src_name)
26-
self.table_dst = table(self.table_dst_name)
16+
super().setUp()
2717

28-
self.conn.query(self.table_src.drop(True))
29-
self.conn.query(self.table_dst.drop(True))
18+
self.conn = self.connection
3019

31-
src_table = table(self.table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str})
32-
self.conn.query(src_table.create())
3320
self.now = now = arrow.get()
3421

3522
rows = [
@@ -39,21 +26,14 @@ def setUp(self) -> None:
3926
(self.now.shift(seconds=-6), "c"),
4027
]
4128

42-
self.conn.query(src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows)))
43-
_commit(self.conn)
44-
45-
self.conn.query(self.table_dst.create(self.table_src))
46-
_commit(self.conn)
47-
48-
self.conn.query(src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago"))
49-
_commit(self.conn)
50-
51-
def tearDown(self) -> None:
52-
self.conn.query(self.table_src.drop(True))
53-
self.conn.query(self.table_dst.drop(True))
54-
_commit(self.conn)
55-
56-
return super().tearDown()
29+
self.conn.query(
30+
[
31+
self.src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows)),
32+
self.dst_table.create(self.src_table),
33+
self.src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago"),
34+
commit,
35+
]
36+
)
5737

5838
def test_api(self):
5939
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
@@ -76,21 +56,8 @@ def test_api(self):
7656
t1.database.close()
7757
t2.database.close()
7858

79-
def test_api_get_stats_string(self):
80-
expected_string = "5 rows in table A\n4 rows in table B\n1 rows exclusive to table A (not present in B)\n0 rows exclusive to table B (not present in A)\n0 rows updated\n4 rows unchanged\n20.00% difference score\n\nExtra-Info:\n rows_downloaded = 5\n"
81-
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
82-
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_dst_name)
83-
diff = diff_tables(t1, t2)
84-
output = diff.get_stats_string()
85-
86-
self.assertEqual(expected_string, output)
87-
self.assertIsNotNone(diff)
88-
assert len(list(diff)) == 1
89-
90-
t1.database.close()
91-
t2.database.close()
92-
9359
def test_api_get_stats_dict(self):
60+
# XXX Likely to change in the future
9461
expected_dict = {
9562
"rows_A": 5,
9663
"rows_B": 4,

tests/test_cli.py

Lines changed: 16 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,12 @@
11
import logging
2-
import unittest
3-
import arrow
42
import subprocess
53
import sys
64
from datetime import datetime, timedelta
75

86
from data_diff.databases import MySQL
9-
from data_diff.sqeleton.queries import table, commit
7+
from data_diff.sqeleton.queries import commit
108

11-
from .common import TEST_MYSQL_CONN_STRING, get_conn, random_table_suffix
12-
13-
14-
def _commit(conn):
15-
conn.query(commit)
9+
from .common import TEST_MYSQL_CONN_STRING, DiffTestCase
1610

1711

1812
def run_datadiff_cli(*args):
@@ -26,23 +20,14 @@ def run_datadiff_cli(*args):
2620
return stdout.splitlines()
2721

2822

29-
class TestCLI(unittest.TestCase):
30-
def setUp(self) -> None:
31-
self.conn = get_conn(MySQL)
32-
33-
suffix = random_table_suffix()
34-
self.table_src_name = f"test_api{suffix}"
35-
self.table_dst_name = f"test_api_2{suffix}"
23+
class TestCLI(DiffTestCase):
24+
db_cls = MySQL
25+
src_schema = {"id": int, "datetime": datetime, "text_comment": str}
3626

37-
self.table_src = table(self.table_src_name)
38-
self.table_dst = table(self.table_dst_name)
39-
self.conn.query(self.table_src.drop(True))
40-
self.conn.query(self.table_dst.drop(True))
27+
def setUp(self) -> None:
28+
super().setUp()
4129

42-
src_table = table(self.table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str})
43-
self.conn.query(src_table.create())
44-
self.conn.query("SET @@session.time_zone='+00:00'")
45-
now = self.conn.query("select now()", datetime)
30+
now = self.connection.query("select now()", datetime)
4631

4732
rows = [
4833
(now, "now"),
@@ -51,21 +36,14 @@ def setUp(self) -> None:
5136
(now - timedelta(seconds=6), "c"),
5237
]
5338

54-
self.conn.query(src_table.insert_rows((i, ts, s) for i, (ts, s) in enumerate(rows)))
55-
_commit(self.conn)
56-
57-
self.conn.query(self.table_dst.create(self.table_src))
58-
_commit(self.conn)
59-
60-
self.conn.query(src_table.insert_row(len(rows), now - timedelta(seconds=3), "3 seconds ago"))
61-
_commit(self.conn)
62-
63-
def tearDown(self) -> None:
64-
self.conn.query(self.table_src.drop(True))
65-
self.conn.query(self.table_dst.drop(True))
66-
_commit(self.conn)
67-
68-
return super().tearDown()
39+
self.connection.query(
40+
[
41+
self.src_table.insert_rows((i, ts, s) for i, (ts, s) in enumerate(rows)),
42+
self.dst_table.create(self.src_table),
43+
self.src_table.insert_row(len(rows), now - timedelta(seconds=3), "3 seconds ago"),
44+
commit,
45+
]
46+
)
6947

7048
def test_basic(self):
7149
diff = run_datadiff_cli(
@@ -91,21 +69,3 @@ def test_options(self):
9169
"1h",
9270
)
9371
assert len(diff) == 1
94-
95-
def test_stats(self):
96-
diff_output = run_datadiff_cli(
97-
TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_dst_name, "-s"
98-
)
99-
assert len(diff_output) == 11
100-
101-
def test_stats_json(self):
102-
diff_output = run_datadiff_cli(
103-
TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_dst_name, "-s", "--json"
104-
)
105-
assert len(diff_output) == 2
106-
107-
def test_stats_no_diff(self):
108-
diff_output = run_datadiff_cli(
109-
TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_src_name, "-s"
110-
)
111-
assert len(diff_output) == 11

0 commit comments

Comments (0)