Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 6ec44a2

Browse files
authored
Merge pull request #227 from datafold/sep6
Refactors and fixes
2 parents 60b050f + d411dfc commit 6ec44a2

File tree

7 files changed

+54
-16
lines changed

7 files changed

+54
-16
lines changed

data_diff/__main__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
from .utils import remove_password_from_url, safezip, match_like
1313
from .diff_tables import TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
14-
from .table_segment import create_schema, TableSegment
14+
from .table_segment import TableSegment
15+
from .databases.database_types import create_schema
1516
from .databases.connect import connect
1617
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
1718
from .config import apply_config_from_file

data_diff/databases/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,13 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None
270270
return f"LIMIT {limit}"
271271

272272
def concat(self, l: List[str]) -> str:
273+
assert len(l) > 1
273274
joined_exprs = ", ".join(l)
274275
return f"concat({joined_exprs})"
275276

277+
def is_distinct_from(self, a: str, b: str) -> str:
278+
return f"{a} is distinct from {b}"
279+
276280
def timestamp_value(self, t: DbTime) -> str:
277281
return f"'{t.isoformat()}'"
278282

data_diff/databases/database_types.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1+
import logging
12
import decimal
23
from abc import ABC, abstractmethod
34
from typing import Sequence, Optional, Tuple, Union, Dict, List
45
from datetime import datetime
56

67
from runtype import dataclass
78

8-
from data_diff.utils import ArithAlphanumeric, ArithUUID, CaseAwareMapping
9+
from data_diff.utils import ArithAlphanumeric, ArithUUID, CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict
910

1011

1112
DbPath = Tuple[str, ...]
1213
DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric]
1314
DbTime = datetime
1415

16+
logger = logging.getLogger("databases")
17+
1518

1619
class ColType:
1720
supported = True
@@ -129,6 +132,8 @@ class UnknownColType(ColType):
129132

130133

131134
class AbstractDatabase(ABC):
135+
name: str
136+
132137
@abstractmethod
133138
def quote(self, s: str):
134139
"Quote SQL name (implementation specific)"
@@ -144,6 +149,11 @@ def concat(self, l: List[str]) -> str:
144149
"Provide SQL for concatenating a bunch of column into a string"
145150
...
146151

152+
@abstractmethod
153+
def is_distinct_from(self, a: str, b: str) -> str:
154+
"Provide SQL for a comparison where NULL = NULL is true"
155+
...
156+
147157
@abstractmethod
148158
def timestamp_value(self, t: DbTime) -> str:
149159
"Provide SQL for the given timestamp value"
@@ -270,3 +280,15 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
270280

271281

272282
Schema = CaseAwareMapping
283+
284+
285+
def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
286+
logger.debug(f"[{db.name}] Schema = {schema}")
287+
288+
if case_sensitive:
289+
return CaseSensitiveDict(schema)
290+
291+
if len({k.lower() for k in schema}) < len(schema):
292+
logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
293+
logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).")
294+
return CaseInsensitiveDict(schema)

data_diff/databases/mysql.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,6 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
6969

7070
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
7171
return f"TRIM(CAST({value} AS char))"
72+
73+
def is_distinct_from(self, a: str, b: str) -> str:
74+
return f"not ({a} <=> {b})"

data_diff/table_segment.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,17 @@
44

55
from runtype import dataclass
66

7-
from .utils import ArithString, split_space, CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict
7+
from .utils import ArithString, split_space
88

99
from .databases.base import Database
10-
from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema
10+
from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema, create_schema
1111
from .sql import Select, Checksum, Compare, Count, TableName, Time, Value
1212

1313
logger = logging.getLogger("table_segment")
1414

1515
RECOMMENDED_CHECKSUM_DURATION = 10
1616

1717

18-
def create_schema(db: Database, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
19-
logger.debug(f"[{db.name}] Schema = {schema}")
20-
21-
if case_sensitive:
22-
return CaseSensitiveDict(schema)
23-
24-
if len({k.lower() for k in schema}) < len(schema):
25-
logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
26-
logger.warning("We recommend to disable case-insensitivity (remove --any-case).")
27-
return CaseInsensitiveDict(schema)
28-
29-
3018
@dataclass
3119
class TableSegment:
3220
"""Signifies a segment of rows (and selected columns) within a table

data_diff/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def items(self) -> Iterable[Tuple[str, V]]:
212212

213213
class CaseSensitiveDict(dict, CaseAwareMapping):
214214
def get_key(self, key):
215+
self[key] # Throw KeyError is key doesn't exist
215216
return key
216217

217218
def as_insensitive(self):

tests/test_diff_tables.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,25 @@ def test_table_segment(self):
521521

522522
self.assertRaises(ValueError, self.table.replace, min_key=10, max_key=0)
523523

524+
def test_case_awareness(self):
525+
# create table
526+
self.connection.query(f"create table {self.table_src}(id int, userid int, timestamp timestamp)", None)
527+
_commit(self.connection)
528+
529+
# insert rows
530+
cols = "id userid timestamp".split()
531+
time = "2022-01-01 00:00:00.000000"
532+
time_str = f"timestamp '{time}'"
533+
_insert_rows(self.connection, self.table_src, cols, [[1, 9, time_str], [2, 2, time_str]])
534+
_commit(self.connection)
535+
536+
res = tuple(self.table.replace(key_column="Id", case_sensitive=False).with_schema().query_key_range())
537+
assert res == ("1", "2")
538+
539+
self.assertRaises(
540+
KeyError, self.table.replace(key_column="Id", case_sensitive=True).with_schema().query_key_range
541+
)
542+
524543

525544
@test_per_database
526545
class TestTableUUID(TestPerDatabase):

0 commit comments

Comments
 (0)