Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 3bc08e2

Browse files
authored
Merge pull request #751 from datafold/squash-specialised-mixins
Simplify: Squash database-specialised mixins into their database-specialised dialects
2 parents: 312a9c5 + b5a015e; commit: 3bc08e2

File tree

13 files changed

+409
-536
lines changed

13 files changed

+409
-536
lines changed

data_diff/databases/bigquery.py

Lines changed: 79 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@
4040
CHECKSUM_HEXDIGITS,
4141
MD5_HEXDIGITS,
4242
)
43-
from data_diff.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter, Mixin_RandomSample
43+
from data_diff.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
4444

4545

4646
@import_helper(text="Please install BigQuery and configure your google-cloud access.")
@@ -63,13 +63,87 @@ def import_bigquery_service_account_impersonation():
6363

6464

6565
@attrs.define(frozen=False)
66-
class Mixin_MD5(AbstractMixin_MD5):
66+
class Dialect(
67+
BaseDialect, AbstractMixin_Schema, AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_TimeTravel
68+
):
69+
name = "BigQuery"
70+
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
71+
TYPE_CLASSES = {
72+
# Dates
73+
"TIMESTAMP": Timestamp,
74+
"DATETIME": Datetime,
75+
# Numbers
76+
"INT64": Integer,
77+
"INT32": Integer,
78+
"NUMERIC": Decimal,
79+
"BIGNUMERIC": Decimal,
80+
"FLOAT64": Float,
81+
"FLOAT32": Float,
82+
"STRING": Text,
83+
"BOOL": Boolean,
84+
"JSON": JSON,
85+
}
86+
TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
87+
TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
88+
89+
def random(self) -> str:
90+
return "RAND()"
91+
92+
def quote(self, s: str):
93+
return f"`{s}`"
94+
95+
def to_string(self, s: str):
96+
return f"cast({s} as string)"
97+
98+
def type_repr(self, t) -> str:
99+
try:
100+
return {str: "STRING", float: "FLOAT64"}[t]
101+
except KeyError:
102+
return super().type_repr(t)
103+
104+
def parse_type(
105+
self,
106+
table_path: DbPath,
107+
col_name: str,
108+
type_repr: str,
109+
*args: Any, # pass-through args
110+
**kwargs: Any, # pass-through args
111+
) -> ColType:
112+
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
113+
if isinstance(col_type, UnknownColType):
114+
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
115+
if m:
116+
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
117+
col_type = Array(item_type=item_type)
118+
119+
# We currently ignore structs' structure, but later can parse it too. Examples:
120+
# - STRUCT<INT64, STRING(10)> (unnamed)
121+
# - STRUCT<foo INT64, bar STRING(10)> (named)
122+
# - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
123+
# - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
124+
m = self.TYPE_STRUCT_RE.fullmatch(type_repr)
125+
if m:
126+
col_type = Struct()
127+
128+
return col_type
129+
130+
def to_comparable(self, value: str, coltype: ColType) -> str:
131+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
132+
if isinstance(coltype, (JSON, Array, Struct)):
133+
return self.normalize_value_by_type(value, coltype)
134+
else:
135+
return super().to_comparable(value, coltype)
136+
137+
def set_timezone_to_utc(self) -> str:
138+
raise NotImplementedError()
139+
140+
def parse_table_name(self, name: str) -> DbPath:
141+
path = parse_table_name(name)
142+
return tuple(i for i in path if i is not None)
143+
67144
def md5_as_int(self, s: str) -> str:
68145
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS})) as int64) as numeric) - {CHECKSUM_OFFSET}"
69146

70-
71-
@attrs.define(frozen=False)
72-
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
73147
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
74148
if coltype.rounds:
75149
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
@@ -112,9 +186,6 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str:
112186
# match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
113187
return f"to_json_string({value})"
114188

115-
116-
@attrs.define(frozen=False)
117-
class Mixin_Schema(AbstractMixin_Schema):
118189
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
119190
return (
120191
table(table_schema, "INFORMATION_SCHEMA", "TABLES")
@@ -126,9 +197,6 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
126197
.select(this.table_name)
127198
)
128199

129-
130-
@attrs.define(frozen=False)
131-
class Mixin_TimeTravel(AbstractMixin_TimeTravel):
132200
def time_travel(
133201
self,
134202
table: Compilable,
@@ -155,86 +223,6 @@ def time_travel(
155223
)
156224

157225

158-
@attrs.define(frozen=False)
159-
class Dialect(
160-
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
161-
):
162-
name = "BigQuery"
163-
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
164-
TYPE_CLASSES = {
165-
# Dates
166-
"TIMESTAMP": Timestamp,
167-
"DATETIME": Datetime,
168-
# Numbers
169-
"INT64": Integer,
170-
"INT32": Integer,
171-
"NUMERIC": Decimal,
172-
"BIGNUMERIC": Decimal,
173-
"FLOAT64": Float,
174-
"FLOAT32": Float,
175-
"STRING": Text,
176-
"BOOL": Boolean,
177-
"JSON": JSON,
178-
}
179-
TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
180-
TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
181-
182-
def random(self) -> str:
183-
return "RAND()"
184-
185-
def quote(self, s: str):
186-
return f"`{s}`"
187-
188-
def to_string(self, s: str):
189-
return f"cast({s} as string)"
190-
191-
def type_repr(self, t) -> str:
192-
try:
193-
return {str: "STRING", float: "FLOAT64"}[t]
194-
except KeyError:
195-
return super().type_repr(t)
196-
197-
def parse_type(
198-
self,
199-
table_path: DbPath,
200-
col_name: str,
201-
type_repr: str,
202-
*args: Any, # pass-through args
203-
**kwargs: Any, # pass-through args
204-
) -> ColType:
205-
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
206-
if isinstance(col_type, UnknownColType):
207-
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
208-
if m:
209-
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
210-
col_type = Array(item_type=item_type)
211-
212-
# We currently ignore structs' structure, but later can parse it too. Examples:
213-
# - STRUCT<INT64, STRING(10)> (unnamed)
214-
# - STRUCT<foo INT64, bar STRING(10)> (named)
215-
# - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
216-
# - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
217-
m = self.TYPE_STRUCT_RE.fullmatch(type_repr)
218-
if m:
219-
col_type = Struct()
220-
221-
return col_type
222-
223-
def to_comparable(self, value: str, coltype: ColType) -> str:
224-
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
225-
if isinstance(coltype, (JSON, Array, Struct)):
226-
return self.normalize_value_by_type(value, coltype)
227-
else:
228-
return super().to_comparable(value, coltype)
229-
230-
def set_timezone_to_utc(self) -> str:
231-
raise NotImplementedError()
232-
233-
def parse_table_name(self, name: str) -> DbPath:
234-
path = parse_table_name(name)
235-
return tuple(i for i in path if i is not None)
236-
237-
238226
@attrs.define(frozen=False, init=False, kw_only=True)
239227
class BigQuery(Database):
240228
CONNECT_URI_HELP = "bigquery://<project>/<dataset>"

data_diff/databases/clickhouse.py

Lines changed: 62 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
ThreadedDatabase,
1212
import_helper,
1313
ConnectError,
14-
Mixin_RandomSample,
1514
)
1615
from data_diff.abcs.database_types import (
1716
ColType,
@@ -39,16 +38,74 @@ def import_clickhouse():
3938

4039

4140
@attrs.define(frozen=False)
42-
class Mixin_MD5(AbstractMixin_MD5):
41+
class Dialect(BaseDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
42+
name = "Clickhouse"
43+
ROUNDS_ON_PREC_LOSS = False
44+
TYPE_CLASSES = {
45+
"Int8": Integer,
46+
"Int16": Integer,
47+
"Int32": Integer,
48+
"Int64": Integer,
49+
"Int128": Integer,
50+
"Int256": Integer,
51+
"UInt8": Integer,
52+
"UInt16": Integer,
53+
"UInt32": Integer,
54+
"UInt64": Integer,
55+
"UInt128": Integer,
56+
"UInt256": Integer,
57+
"Float32": Float,
58+
"Float64": Float,
59+
"Decimal": Decimal,
60+
"UUID": Native_UUID,
61+
"String": Text,
62+
"FixedString": Text,
63+
"DateTime": Timestamp,
64+
"DateTime64": Timestamp,
65+
"Bool": Boolean,
66+
}
67+
68+
def quote(self, s: str) -> str:
69+
return f'"{s}"'
70+
71+
def to_string(self, s: str) -> str:
72+
return f"toString({s})"
73+
74+
def _convert_db_precision_to_digits(self, p: int) -> int:
75+
# Done the same as for PostgreSQL but need to rewrite in another way
76+
# because it does not help for float with a big integer part.
77+
return super()._convert_db_precision_to_digits(p) - 2
78+
79+
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
80+
nullable_prefix = "Nullable("
81+
if type_repr.startswith(nullable_prefix):
82+
type_repr = type_repr[len(nullable_prefix) :].rstrip(")")
83+
84+
if type_repr.startswith("Decimal"):
85+
type_repr = "Decimal"
86+
elif type_repr.startswith("FixedString"):
87+
type_repr = "FixedString"
88+
elif type_repr.startswith("DateTime64"):
89+
type_repr = "DateTime64"
90+
91+
return self.TYPE_CLASSES.get(type_repr)
92+
93+
# def timestamp_value(self, t: DbTime) -> str:
94+
# # return f"'{t}'"
95+
# return f"'{str(t)[:19]}'"
96+
97+
def set_timezone_to_utc(self) -> str:
98+
raise NotImplementedError()
99+
100+
def current_timestamp(self) -> str:
101+
return "now()"
102+
43103
def md5_as_int(self, s: str) -> str:
44104
substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS
45105
return (
46106
f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx}))))) - {CHECKSUM_OFFSET}"
47107
)
48108

49-
50-
@attrs.define(frozen=False)
51-
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
52109
def normalize_number(self, value: str, coltype: FractionalType) -> str:
53110
# If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped.
54111
# For example:
@@ -106,70 +163,6 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
106163
return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')"
107164

108165

109-
@attrs.define(frozen=False)
110-
class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
111-
name = "Clickhouse"
112-
ROUNDS_ON_PREC_LOSS = False
113-
TYPE_CLASSES = {
114-
"Int8": Integer,
115-
"Int16": Integer,
116-
"Int32": Integer,
117-
"Int64": Integer,
118-
"Int128": Integer,
119-
"Int256": Integer,
120-
"UInt8": Integer,
121-
"UInt16": Integer,
122-
"UInt32": Integer,
123-
"UInt64": Integer,
124-
"UInt128": Integer,
125-
"UInt256": Integer,
126-
"Float32": Float,
127-
"Float64": Float,
128-
"Decimal": Decimal,
129-
"UUID": Native_UUID,
130-
"String": Text,
131-
"FixedString": Text,
132-
"DateTime": Timestamp,
133-
"DateTime64": Timestamp,
134-
"Bool": Boolean,
135-
}
136-
137-
def quote(self, s: str) -> str:
138-
return f'"{s}"'
139-
140-
def to_string(self, s: str) -> str:
141-
return f"toString({s})"
142-
143-
def _convert_db_precision_to_digits(self, p: int) -> int:
144-
# Done the same as for PostgreSQL but need to rewrite in another way
145-
# because it does not help for float with a big integer part.
146-
return super()._convert_db_precision_to_digits(p) - 2
147-
148-
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
149-
nullable_prefix = "Nullable("
150-
if type_repr.startswith(nullable_prefix):
151-
type_repr = type_repr[len(nullable_prefix) :].rstrip(")")
152-
153-
if type_repr.startswith("Decimal"):
154-
type_repr = "Decimal"
155-
elif type_repr.startswith("FixedString"):
156-
type_repr = "FixedString"
157-
elif type_repr.startswith("DateTime64"):
158-
type_repr = "DateTime64"
159-
160-
return self.TYPE_CLASSES.get(type_repr)
161-
162-
# def timestamp_value(self, t: DbTime) -> str:
163-
# # return f"'{t}'"
164-
# return f"'{str(t)[:19]}'"
165-
166-
def set_timezone_to_utc(self) -> str:
167-
raise NotImplementedError()
168-
169-
def current_timestamp(self) -> str:
170-
return "now()"
171-
172-
173166
@attrs.define(frozen=False, init=False, kw_only=True)
174167
class Clickhouse(ThreadedDatabase):
175168
dialect = Dialect()

0 commit comments

Comments
 (0)