Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 706a34c

Browse files
committed
Add Trino connector
1 parent f657f47 commit 706a34c

File tree

15 files changed

+179
-3
lines changed

15 files changed

+179
-3
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,6 @@ benchmark_*.png
138138

139139
# Mac
140140
.DS_Store
141+
142+
# IntelliJ
143+
.idea

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,14 @@ $ data-diff \
125125
| BigQuery | `bigquery://<project>/<dataset>` | 💛 |
126126
| Redshift | `redshift://<username>:<password>@<hostname>:5439/<database>` | 💛 |
127127
| Presto | `presto://<username>:<password>@<hostname>:8080/<database>` | 💛 |
128+
| Databricks | `databricks://<http_path>:<access_token>@<server_hostname>/<catalog>/<schema>` | 💛 |
| Trino | `trino://<username>:<password>@<hostname>:8080/<database>` | 💛 |
| ElasticSearch | | 📝 |
130136
| Planetscale | | 📝 |
131137
| Clickhouse | | 📝 |
132138
| Pinot | | 📝 |
@@ -505,7 +511,7 @@ Now you can insert it into the testing database(s):
505511
```shell-session
506512
# It's optional to seed more than one to run data-diff(1) against.
507513
$ poetry run preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql
508-
$ poetry run preql -f dev/prepare_db.pql postgresql://postgres:Password1@127.0.0.1:5432/postgres
514+
$ poetry run preql -f dev/prepare_db.pql postgres://postgres:Password1@127.0.0.1:5432/postgres
509515
510516
# Cloud databases
511517
$ poetry run preql -f dev/prepare_db.pql snowflake://<uri>

data_diff/databases/connect.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .redshift import Redshift
1414
from .presto import Presto
1515
from .databricks import Databricks
16+
from .trino import Trino
1617

1718

1819
@dataclass
@@ -80,7 +81,8 @@ def match_path(self, dsn):
8081
"bigquery": MatchUriPath(BigQuery, ["dataset"], help_str="bigquery://<project>/<dataset>"),
8182
"databricks": MatchUriPath(
8283
Databricks, ["catalog", "schema"], help_str="databricks://:access_token@server_name/http_path",
83-
)
84+
),
85+
"trino": MatchUriPath(Trino, ["catalog", "schema"], help_str="trino://<user>@<host>/<catalog>/<schema>"),
8486
}
8587

8688

@@ -105,6 +107,7 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
105107
- redshift
106108
- presto
107109
- databricks
110+
- trino
108111
"""
109112

110113
dsn = dsnparse.parse(db_uri)

data_diff/databases/trino.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import re
2+
3+
from .database_types import *
4+
from .base import Database, import_helper
5+
from .base import (
6+
MD5_HEXDIGITS,
7+
CHECKSUM_HEXDIGITS,
8+
TIMESTAMP_PRECISION_POS,
9+
DEFAULT_DATETIME_PRECISION,
10+
)
11+
12+
13+
@import_helper("trino")
def import_trino():
    """Import the ``trino`` package lazily and hand the module back to the caller."""
    # The import lives inside the function so that it only runs when this
    # helper is actually called, not when this module is imported.
    import trino as trino_module

    return trino_module
18+
19+
20+
class Trino(Database):
    """Database driver for Trino, built on the ``trino`` package's DB-API connection.

    The methods below mostly return SQL *fragments* (strings) that wrap a column
    expression so that values are rendered in a canonical textual form.
    """

    default_schema = "public"

    # Un-parameterized type names mapped to column-type classes.
    # Parameterized spellings such as ``timestamp(3)`` or ``varchar(10)`` are
    # recognised by the regular expressions in ``_parse_type``.
    TYPE_CLASSES = {
        # Timestamps
        "timestamp with time zone": TimestampTZ,
        "timestamp without time zone": Timestamp,
        "timestamp": Timestamp,
        # Numbers
        "integer": Integer,
        "bigint": Integer,
        "real": Float,
        "double": Float,
        # Text
        "varchar": Text,
    }

    # Passed as ``rounds=`` to every temporal type created in ``_parse_type``.
    ROUNDS_ON_PREC_LOSS = True

    def __init__(self, **kw):
        """Open a connection; ``kw`` is forwarded verbatim to ``trino.dbapi.connect``."""
        trino = import_trino()

        self._conn = trino.dbapi.connect(**kw)

    def quote(self, s: str):
        """Return ``s`` wrapped in double quotes for use as an identifier."""
        return f'"{s}"'

    def md5_to_int(self, s: str) -> str:
        """Return SQL that hashes expression ``s`` with MD5 and converts the hex tail to a decimal.

        Only the trailing ``CHECKSUM_HEXDIGITS`` hex digits of the digest are kept
        (``substr`` starts at ``1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS``).
        """
        return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))"

    def to_string(self, s: str):
        """Return SQL that casts expression ``s`` to ``varchar``."""
        return f"cast({s} as varchar)"

    def _query(self, sql_code: str) -> list:
        """Execute ``sql_code`` on a new cursor of the open connection.

        Returns:
            All fetched rows when the statement begins with ``select``
            (case-insensitive, leading whitespace ignored); ``None`` otherwise.
        """
        c = self._conn.cursor()
        c.execute(sql_code)
        # Strip first so queries that start with a newline or indentation are
        # still recognised as SELECTs and have their rows returned.
        if sql_code.lstrip().lower().startswith("select"):
            return c.fetchall()

    def close(self):
        """Close the underlying connection."""
        self._conn.close()

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL rendering timestamp expression ``value`` as fixed-width text.

        Args:
            value: SQL expression producing the timestamp.
            coltype: Temporal column type; its ``precision`` and ``rounds``
                attributes select how the fractional seconds are produced.

        The value is formatted with six fractional digits, cut to
        ``TIMESTAMP_PRECISION_POS + coltype.precision`` characters, then padded
        back with ``'0'`` to ``TIMESTAMP_PRECISION_POS + 6`` characters.
        """
        if coltype.rounds:
            # Cast the *value* (not the precision number) to the column's own
            # precision so that Trino rounds the fractional seconds before the
            # text is cut below.
            s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')"
        else:
            s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"

        return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return SQL rendering ``value`` as text via ``decimal(38, <coltype.precision>)``."""
        return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")

    def select_table_schema(self, path: DbPath) -> str:
        """Return the SQL that lists column names and data types of the table at ``path``.

        The datetime and numeric precision columns are constant ``3`` in this
        query; real timestamp precision is extracted from the ``data_type`` text
        by ``_parse_type``.
        """
        schema, table = self._normalize_table_path(path)

        return (
            f"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision FROM INFORMATION_SCHEMA.COLUMNS "
            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
        )

    def _parse_type(
        self,
        table_path: DbPath,
        col_name: str,
        type_repr: str,
        datetime_precision: int = None,
        numeric_precision: int = None,
    ) -> ColType:
        """Translate a type string into a column-type object.

        Parameterized timestamp, decimal, varchar and char spellings are
        handled here; anything else is delegated to the parent implementation.

        Args:
            table_path: Path of the table the column belongs to (forwarded to the parent).
            col_name: Column name (forwarded to the parent).
            type_repr: Type text, e.g. ``timestamp(6)`` or ``decimal(10,2)``.
            datetime_precision: Forwarded to the parent when no pattern matches.
            numeric_precision: Forwarded to the parent when no pattern matches.
        """
        # ``\d+`` rather than ``\d``: timestamp precision may have two digits.
        timestamp_regexps = {
            r"timestamp\((\d+)\)": Timestamp,
            r"timestamp\((\d+)\) with time zone": TimestampTZ,
        }
        for regexp, t_cls in timestamp_regexps.items():
            m = re.match(regexp + "$", type_repr)
            if m:
                # The precision embedded in the type text takes priority over
                # the ``datetime_precision`` argument.
                return t_cls(
                    precision=int(m.group(1)),
                    rounds=self.ROUNDS_ON_PREC_LOSS,
                )

        number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
        for regexp, n_cls in number_regexps.items():
            m = re.match(regexp + "$", type_repr)
            if m:
                # Only the scale (second number) is kept on the resulting type.
                scale = int(m.group(2))
                return n_cls(scale)

        string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text}
        for regexp, s_cls in string_regexps.items():
            m = re.match(regexp + "$", type_repr)
            if m:
                # The declared length is ignored.
                return s_cls()

        return super()._parse_type(
            table_path, col_name, type_repr, datetime_precision, numeric_precision
        )

    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Return SQL that trims surrounding whitespace from UUID expression ``value``."""
        return f"TRIM({value})"

data_diff/diff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ class TableDiffer:
274274
The algorithm uses hashing to quickly check if the tables are different, and then applies a
275275
bisection search recursively to find the differences efficiently.
276276
277-
Works best for comparing tables that are mostly the name, with minor discrepencies.
277+
Works best for comparing tables that are mostly the same, with minor discrepancies.
278278
279279
Parameters:
280280
bisection_factor (int): Into how many segments to bisect per iteration.

debug.py

Whitespace-only changes.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
connector.name=memory
2+
memory.max-data-per-node=128MB
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
connector.name=postgresql
2+
connection-url=jdbc:postgresql://postgres:5432/postgres
3+
connection-user=postgres
4+
connection-password=Password1
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
connector.name=tpcds
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
connector.name=tpch

0 commit comments

Comments
 (0)