This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit f657f47

Merge pull request #160 from datafold/pik94-support-databricks

Support Databricks 2

2 parents 35af3e0 + 13b94bb · commit f657f47

File tree

8 files changed: +422 -41 lines

README.md

Lines changed: 16 additions & 16 deletions
@@ -116,22 +116,22 @@ $ data-diff \
 
 ## Supported Databases
 
-| Database | Connection string | Status |
-|---------------|------------------------------------------------------------------------------------------------------------------------------------|--------|
-| PostgreSQL | `postgresql://<user>:<password>@<hostname>:5432/<database>` | 💚 |
-| MySQL | `mysql://<user>:<password>@<hostname>:5432/<database>` | 💚 |
-| Snowflake | `"snowflake://<user>[:<password>]@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>&role=<role>[&authenticator=externalbrowser]"`| 💚 |
-| Oracle | `oracle://<username>:<password>@<hostname>/database` | 💛 |
-| BigQuery | `bigquery://<project>/<dataset>` | 💛 |
-| Redshift | `redshift://<username>:<password>@<hostname>:5439/<database>` | 💛 |
-| Presto | `presto://<username>:<password>@<hostname>:8080/<database>` | 💛 |
-| ElasticSearch | | 📝 |
-| Databricks | | 📝 |
-| Planetscale | | 📝 |
-| Clickhouse | | 📝 |
-| Pinot | | 📝 |
-| Druid | | 📝 |
-| Kafka | | 📝 |
+| Database | Connection string | Status |
+|---------------|-------------------------------------------------------------------------------------------------------------------------------------|--------|
+| PostgreSQL | `postgresql://<user>:<password>@<hostname>:5432/<database>` | 💚 |
+| MySQL | `mysql://<user>:<password>@<hostname>:5432/<database>` | 💚 |
+| Snowflake | `"snowflake://<user>[:<password>]@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>&role=<role>[&authenticator=externalbrowser]"` | 💚 |
+| Oracle | `oracle://<username>:<password>@<hostname>/database` | 💛 |
+| BigQuery | `bigquery://<project>/<dataset>` | 💛 |
+| Redshift | `redshift://<username>:<password>@<hostname>:5439/<database>` | 💛 |
+| Presto | `presto://<username>:<password>@<hostname>:8080/<database>` | 💛 |
+| Databricks | `databricks://<http_path>:<access_token>@<server_hostname>/<catalog>/<schema>` | 💛 |
+| ElasticSearch | | 📝 |
+| Planetscale | | 📝 |
+| Clickhouse | | 📝 |
+| Pinot | | 📝 |
+| Druid | | 📝 |
+| Kafka | | 📝 |
 
 * 💚: Implemented and thoroughly tested.
 * 💛: Implemented, but not thoroughly tested yet.
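For orientation, here is a minimal sketch of opening a Databricks connection through `connect_to_uri` (exported from `data_diff.databases`, as the `__init__.py` change below shows). It follows the URI layout accepted by the parser added in `connect.py` further down: an empty user, the access token in the password slot, the SQL endpoint's HTTP path as the URI path, and catalog/schema as query parameters. The hostname, HTTP path, and token are placeholders, not values from this commit.

    from data_diff.databases import connect_to_uri

    # Placeholders only; substitute a real workspace hostname, HTTP path, and token.
    uri = (
        "databricks://:dapi0000000000000000@adb-1111111111111111.11.azuredatabricks.net"
        "/sql/1.0/endpoints/0123456789abcdef?catalog=hive_metastore&schema=default"
    )
    db = connect_to_uri(uri)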

data_diff/databases/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,5 +7,6 @@
 from .bigquery import BigQuery
 from .redshift import Redshift
 from .presto import Presto
+from .databricks import Databricks
 
 from .connect import connect_to_uri
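A trivial consequence of this export, shown for completeness (not part of the commit): the new adapter can now be imported directly alongside the connection helper.

    from data_diff.databases import Databricks, connect_to_uri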

data_diff/databases/connect.py

Lines changed: 29 additions & 14 deletions
@@ -12,6 +12,7 @@
 from .bigquery import BigQuery
 from .redshift import Redshift
 from .presto import Presto
+from .databricks import Databricks
 
 
 @dataclass
@@ -77,6 +78,9 @@ def match_path(self, dsn):
     ),
     "presto": MatchUriPath(Presto, ["catalog", "schema"], help_str="presto://<user>@<host>/<catalog>/<schema>"),
     "bigquery": MatchUriPath(BigQuery, ["dataset"], help_str="bigquery://<project>/<dataset>"),
+    "databricks": MatchUriPath(
+        Databricks, ["catalog", "schema"], help_str="databricks://:access_token@server_name/http_path",
+    )
 }
 
 
@@ -100,6 +104,7 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
     - bigquery
     - redshift
     - presto
+    - databricks
     """
 
     dsn = dsnparse.parse(db_uri)
@@ -113,23 +118,33 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
         raise NotImplementedError(f"Scheme {scheme} currently not supported")
 
     cls = matcher.database_cls
-    kw = matcher.match_path(dsn)
 
-    if scheme == "bigquery":
-        kw["project"] = dsn.host
-        return cls(**kw)
-
-    if scheme == "snowflake":
-        kw["account"] = dsn.host
-        assert not dsn.port
-        kw["user"] = dsn.user
-        kw["password"] = dsn.password
+    if scheme == "databricks":
+        assert not dsn.user
+        kw = {}
+        kw['access_token'] = dsn.password
+        kw['http_path'] = dsn.path
+        kw['server_hostname'] = dsn.host
+        kw.update(dsn.query)
     else:
-        kw["host"] = dsn.host
-        kw["port"] = dsn.port
-        kw["user"] = dsn.user
-        if dsn.password:
+        kw = matcher.match_path(dsn)
+
+        if scheme == "bigquery":
+            kw["project"] = dsn.host
+            return cls(**kw)
+
+        if scheme == "snowflake":
+            kw["account"] = dsn.host
+            assert not dsn.port
+            kw["user"] = dsn.user
             kw["password"] = dsn.password
+        else:
+            kw["host"] = dsn.host
+            kw["port"] = dsn.port
+            kw["user"] = dsn.user
+            if dsn.password:
+                kw["password"] = dsn.password
+
     kw = {k: v for k, v in kw.items() if v is not None}
 
     if issubclass(cls, ThreadedDatabase):
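To make the new `databricks` branch concrete, a hedged sketch of the keyword arguments it would assemble for the placeholder URI used earlier, before the `Databricks` instance is constructed (the field mapping follows the code above; whether `dsn.path` keeps its leading slash is an assumption about dsnparse, not something this diff shows):

    # scheme == "databricks": matcher.match_path(dsn) is bypassed entirely.
    kw = {
        "access_token": "dapi0000000000000000",                            # dsn.password
        "http_path": "/sql/1.0/endpoints/0123456789abcdef",                # dsn.path
        "server_hostname": "adb-1111111111111111.11.azuredatabricks.net",  # dsn.host
        "catalog": "hive_metastore",                                       # dsn.query
        "schema": "default",                                               # dsn.query
    }
    # None values are then filtered out; since Databricks subclasses Database
    # rather than ThreadedDatabase, the issubclass branch at the end of the
    # hunk does not apply to it.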

data_diff/databases/databricks.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+import logging
+import math
+
+from .database_types import *
+from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, _query_conn, parse_table_name
+
+
+@import_helper("databricks")
+def import_databricks():
+    import databricks.sql
+
+    return databricks
+
+
+class Databricks(Database):
+    TYPE_CLASSES = {
+        # Numbers
+        "INT": Integer,
+        "SMALLINT": Integer,
+        "TINYINT": Integer,
+        "BIGINT": Integer,
+        "FLOAT": Float,
+        "DOUBLE": Float,
+        "DECIMAL": Decimal,
+        # Timestamps
+        "TIMESTAMP": Timestamp,
+        # Text
+        "STRING": Text,
+    }
+
+    ROUNDS_ON_PREC_LOSS = True
+
+    def __init__(
+        self,
+        http_path: str,
+        access_token: str,
+        server_hostname: str,
+        catalog: str = "hive_metastore",
+        schema: str = "default",
+        **kwargs,
+    ):
+        databricks = import_databricks()
+
+        self._conn = databricks.sql.connect(
+            server_hostname=server_hostname, http_path=http_path, access_token=access_token
+        )
+
+        logging.getLogger("databricks.sql").setLevel(logging.WARNING)
+
+        self.catalog = catalog
+        self.default_schema = schema
+        self.kwargs = kwargs
+
+    def _query(self, sql_code: str) -> list:
+        "Uses the standard SQL cursor interface"
+        return _query_conn(self._conn, sql_code)
+
+    def quote(self, s: str):
+        return f"`{s}`"
+
+    def md5_to_int(self, s: str) -> str:
+        return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))"
+
+    def to_string(self, s: str) -> str:
+        return f"cast({s} as string)"
+
+    def _convert_db_precision_to_digits(self, p: int) -> int:
+        # Subtracting 1 due to wierd precision issues
+        return max(super()._convert_db_precision_to_digits(p) - 1, 0)
+
+    def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
+        # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
+        # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
+        # So, to obtain information about schema, we should use another approach.
+
+        schema, table = self._normalize_table_path(path)
+        with self._conn.cursor() as cursor:
+            cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table)
+            rows = cursor.fetchall()
+            if not rows:
+                raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
+
+            if filter_columns is not None:
+                accept = {i.lower() for i in filter_columns}
+                rows = [r for r in rows if r.COLUMN_NAME.lower() in accept]
+
+            resulted_rows = []
+            for row in rows:
+                row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME
+                type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType)
+
+                if issubclass(type_cls, Integer):
+                    row = (row.COLUMN_NAME, row_type, None, None, 0)
+
+                elif issubclass(type_cls, Float):
+                    numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS)
+                    row = (row.COLUMN_NAME, row_type, None, numeric_precision, None)
+
+                elif issubclass(type_cls, Decimal):
+                    # TYPE_NAME has a format DECIMAL(x,y)
+                    items = row.TYPE_NAME[8:].rstrip(")").split(",")
+                    numeric_precision, numeric_scale = int(items[0]), int(items[1])
+                    row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale)
+
+                elif issubclass(type_cls, Timestamp):
+                    row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None)
+
+                else:
+                    row = (row.COLUMN_NAME, row_type, None, None, None)
+
+                resulted_rows.append(row)
+            col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in resulted_rows}
+
+            self._refine_coltypes(path, col_dict)
+            return col_dict
+
+    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
+        """Databricks timestamp contains no more than 6 digits in precision"""
+
+        if coltype.rounds:
+            timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)"
+            return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')"
+        else:
+            precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision)
+            return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"
+
+    def normalize_number(self, value: str, coltype: NumericType) -> str:
+        return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
+
+    def parse_table_name(self, name: str) -> DbPath:
+        path = parse_table_name(name)
+        return self._normalize_table_path(path)
+
+    def close(self):
+        self._conn.close()
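As a quick illustration of the SQL fragments the new normalization helpers emit, here is a standalone sketch that reproduces their string formatting. It deliberately avoids instantiating `Databricks`, since the constructor opens a real connection; the column names and precision are made up.

    # normalize_timestamp, non-rounding branch, precision 3: pad out to 6 digits.
    value, precision = "`updated_at`", 3
    precision_format = "S" * precision + "0" * (6 - precision)        # "SSS000"
    print(f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')")
    # -> date_format(`updated_at`, 'yyyy-MM-dd HH:mm:ss.SSS000')

    # normalize_number: cast to a fixed-scale decimal, then to string.
    print(f"cast(cast(`amount` as decimal(38, {precision})) as string)")
    # -> cast(cast(`amount` as decimal(38, 3)) as string)

    # The DECIMAL branch of query_table_schema pulls precision/scale out of TYPE_NAME.
    items = "DECIMAL(10,2)"[8:].rstrip(")").split(",")                # strips "DECIMAL("
    print(int(items[0]), int(items[1]))                               # -> 10 2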
