Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 9565b99

Browse files
dlawin authored and erezsh committed
enable repeated iteration of diff
1 parent 62e15c0 commit 9565b99

File tree

4 files changed

+26
-38
lines changed

4 files changed

+26
-38
lines changed

data_diff/__main__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -408,8 +408,6 @@ def _main(
408408
diff_iter = islice(diff_iter, int(limit))
409409

410410
if stats:
411-
# required to create this variable before get_stats
412-
diff_list = list(diff_iter)
413411
if json_output:
414412
rich.print(json.dumps(diff_iter.get_stats_dict()))
415413
else:

data_diff/diff_tables.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -94,34 +94,31 @@ class DiffResultWrapper:
9494
result_list: list = field(default_factory=list)
9595

9696
def __iter__(self):
97+
yield from self.result_list
9798
for i in self.diff:
9899
self.result_list.append(i)
99100
yield i
100101

101102
def _get_stats(self) -> DiffStats:
102103

104+
list(self)
103105
diff_by_key = {}
104-
if len(self.result_list) > 0:
105-
for sign, values in self.result_list:
106-
k = values[: len(self.info_tree.info.tables[0].key_columns)]
107-
if k in diff_by_key:
108-
assert sign != diff_by_key[k]
109-
diff_by_key[k] = "!"
110-
else:
111-
diff_by_key[k] = sign
112-
113-
diff_by_sign = {k: 0 for k in "+-!"}
114-
for sign in diff_by_key.values():
115-
diff_by_sign[sign] += 1
116-
117-
table1_count = self.info_tree.info.rowcounts[1]
118-
table2_count = self.info_tree.info.rowcounts[2]
119-
unchanged = table1_count - diff_by_sign["-"] - diff_by_sign["!"]
120-
diff_percent = 1 - unchanged / max(table1_count, table2_count)
121-
else:
122-
raise RuntimeError(
123-
"result_list is empty, consume the diff iterator to populate values: e.g. \ndiff_iter = diff_tables(...) \ndiff_list = list(diff_iter) \ndiff_iter.print_stats(json_output)"
124-
)
106+
for sign, values in self.result_list:
107+
k = values[: len(self.info_tree.info.tables[0].key_columns)]
108+
if k in diff_by_key:
109+
assert sign != diff_by_key[k]
110+
diff_by_key[k] = "!"
111+
else:
112+
diff_by_key[k] = sign
113+
114+
diff_by_sign = {k: 0 for k in "+-!"}
115+
for sign in diff_by_key.values():
116+
diff_by_sign[sign] += 1
117+
118+
table1_count = self.info_tree.info.rowcounts[1]
119+
table2_count = self.info_tree.info.rowcounts[2]
120+
unchanged = table1_count - diff_by_sign["-"] - diff_by_sign["!"]
121+
diff_percent = 1 - unchanged / max(table1_count, table2_count)
125122

126123
return DiffStats(diff_by_sign, table1_count, table2_count, unchanged, diff_percent)
127124

tests/test_api.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,11 @@ def test_api_get_stats_string(self):
8181
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
8282
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_dst_name)
8383
diff = diff_tables(t1, t2)
84-
diff_list = list(diff)
8584
output = diff.get_stats_string()
8685

8786
self.assertEqual(expected_string, output)
8887
self.assertIsNotNone(diff)
89-
assert len(diff_list) == 1
88+
assert len(list(diff)) == 1
9089

9190
t1.database.close()
9291
t2.database.close()
@@ -96,23 +95,11 @@ def test_api_get_stats_dict(self):
9695
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
9796
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_dst_name)
9897
diff = diff_tables(t1, t2)
99-
diff_list = list(diff)
10098
output = diff.get_stats_dict()
10199

102100
self.assertEqual(expected_dict, output)
103101
self.assertIsNotNone(diff)
104-
assert len(diff_list) == 1
105-
106-
t1.database.close()
107-
t2.database.close()
108-
109-
def test_api_print_error(self):
110-
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
111-
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, (self.table_dst_name,))
112-
diff = diff_tables(t1, t2)
113-
114-
with self.assertRaises(RuntimeError):
115-
diff.get_stats_string()
102+
assert len(list(diff)) == 1
116103

117104
t1.database.close()
118105
t2.database.close()

tests/test_cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,9 @@ def test_stats_json(self):
101101
TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_dst_name, "-s", "--json"
102102
)
103103
assert len(diff_output) == 2
104+
105+
def test_stats_no_diff(self):
106+
diff_output = run_datadiff_cli(
107+
TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_src_name, "-s"
108+
)
109+
assert len(diff_output) == 11

0 commit comments

Comments
 (0)