Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit a40a226

Browse files
committed
Refactored create_schema() into database_types
1 parent 60b050f commit a40a226

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

data_diff/__main__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
from .utils import remove_password_from_url, safezip, match_like
1313
from .diff_tables import TableDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
14-
from .table_segment import create_schema, TableSegment
14+
from .table_segment import TableSegment
15+
from .databases.database_types import create_schema
1516
from .databases.connect import connect
1617
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
1718
from .config import apply_config_from_file

data_diff/databases/database_types.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1+
import logging
12
import decimal
23
from abc import ABC, abstractmethod
34
from typing import Sequence, Optional, Tuple, Union, Dict, List
45
from datetime import datetime
56

67
from runtype import dataclass
78

8-
from data_diff.utils import ArithAlphanumeric, ArithUUID, CaseAwareMapping
9+
from data_diff.utils import ArithAlphanumeric, ArithUUID, CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict
910

1011

1112
DbPath = Tuple[str, ...]
1213
DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric]
1314
DbTime = datetime
1415

16+
logger = logging.getLogger("databases")
17+
1518

1619
class ColType:
1720
supported = True
@@ -129,6 +132,8 @@ class UnknownColType(ColType):
129132

130133

131134
class AbstractDatabase(ABC):
135+
name: str
136+
132137
@abstractmethod
133138
def quote(self, s: str):
134139
"Quote SQL name (implementation specific)"
@@ -270,3 +275,15 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
270275

271276

272277
Schema = CaseAwareMapping
278+
279+
280+
def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
    """Wrap a raw column mapping in a case-aware schema mapping.

    Args:
        db: Database the table belongs to; only used in log messages.
        table_path: Path components of the table; joined with "." in the warning.
        schema: Mapping keyed by column name (keys must support ``.lower()``).
        case_sensitive: If true, return a ``CaseSensitiveDict``; otherwise a
            ``CaseInsensitiveDict``.

    Returns:
        ``CaseSensitiveDict(schema)`` when ``case_sensitive`` is true, else
        ``CaseInsensitiveDict(schema)``.
    """
    # Lazy %-style arguments: the schema is only rendered if the record is emitted.
    logger.debug("[%s] Schema = %s", db.name, schema)

    if case_sensitive:
        return CaseSensitiveDict(schema)

    # Fewer distinct lower-cased names than columns means at least two column
    # names differ only by letter case, so warn before going case-insensitive.
    if len({k.lower() for k in schema}) < len(schema):
        logger.warning(
            "Ambiguous schema for %s:%s | Columns = %s",
            db,
            ".".join(table_path),
            ", ".join(schema),
        )
        logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).")
    return CaseInsensitiveDict(schema)

data_diff/table_segment.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,17 @@
44

55
from runtype import dataclass
66

7-
from .utils import ArithString, split_space, CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict
7+
from .utils import ArithString, split_space
88

99
from .databases.base import Database
10-
from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema
10+
from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema, create_schema
1111
from .sql import Select, Checksum, Compare, Count, TableName, Time, Value
1212

1313
logger = logging.getLogger("table_segment")
1414

1515
RECOMMENDED_CHECKSUM_DURATION = 10
1616

1717

18-
def create_schema(db: Database, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
19-
logger.debug(f"[{db.name}] Schema = {schema}")
20-
21-
if case_sensitive:
22-
return CaseSensitiveDict(schema)
23-
24-
if len({k.lower() for k in schema}) < len(schema):
25-
logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
26-
logger.warning("We recommend to disable case-insensitivity (remove --any-case).")
27-
return CaseInsensitiveDict(schema)
28-
29-
3018
@dataclass
3119
class TableSegment:
3220
"""Signifies a segment of rows (and selected columns) within a table

0 commit comments

Comments
 (0)