Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 79a419c

Browse files
committed
Set session timezone to UTC if possible
1 parent ab55f21 commit 79a419c

File tree

17 files changed

+73
-15
lines changed

17 files changed

+73
-15
lines changed

data_diff/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def _main(
317317
logging.error(e)
318318
return
319319

320+
320321
now: datetime = db1.query(current_timestamp(), datetime)
321322
now = now.replace(tzinfo=None)
322323
try:

data_diff/databases/_connect.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from data_diff.sqeleton.databases import Connect
1+
from data_diff.sqeleton.databases import Connect, Database
2+
import logging
23

34
from .postgresql import PostgreSQL
45
from .mysql import MySQL
@@ -29,4 +30,22 @@
2930
"vertica": Vertica,
3031
}
3132

32-
connect = Connect(DATABASE_BY_SCHEME)
33+
34+
class Connect_SetUTC(Connect):
35+
"""Provides methods for connecting to a supported database using a URL or connection dict.
36+
37+
Ensures all sessions use UTC Timezone, if possible.
38+
"""
39+
40+
def _connection_created(self, db):
41+
db = super()._connection_created(db)
42+
try:
43+
db.query(db.dialect.set_timezone_to_utc())
44+
except NotImplementedError:
45+
logging.debug(
46+
f"Database '{db}' does not allow setting timezone. We recommend making sure it's set to 'UTC'."
47+
)
48+
return db
49+
50+
51+
connect = Connect_SetUTC(DATABASE_BY_SCHEME)

data_diff/joindiff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def sample(table_expr):
6262

6363
def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
6464
db = c.database
65-
c = c.replace(root=False) # we're compiling fragments, not full queries
65+
c = c.replace(root=False) # we're compiling fragments, not full queries
6666
if isinstance(db, BigQuery):
6767
return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}"
6868
elif isinstance(db, Presto):

data_diff/sqeleton/abcs/database_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ def timestamp_value(self, t: datetime) -> str:
203203
"Provide SQL for the given timestamp value"
204204
...
205205

206+
@abstractmethod
207+
def set_timezone_to_utc(self) -> str:
208+
"Provide SQL for setting the session timezone to UTC"
206209

207210
@abstractmethod
208211
def parse_type(

data_diff/sqeleton/databases/bigquery.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ def type_repr(self, t) -> str:
101101
except KeyError:
102102
return super().type_repr(t)
103103

104+
def set_timezone_to_utc(self) -> str:
105+
raise NotImplementedError()
106+
104107

105108
class BigQuery(Database):
106109
CONNECT_URI_HELP = "bigquery://<project>/<dataset>"

data_diff/sqeleton/databases/clickhouse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
150150
# # return f"'{t}'"
151151
# return f"'{str(t)[:19]}'"
152152

153+
def set_timezone_to_utc(self) -> str:
154+
raise NotImplementedError()
155+
153156

154157
class Clickhouse(ThreadedDatabase):
155158
dialect = Dialect()

data_diff/sqeleton/databases/connect.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def match_path(self, dsn):
9393

9494

9595
class Connect:
96+
"""Provides methods for connecting to a supported database using a URL or connection dict."""
97+
9698
def __init__(self, database_by_scheme: Dict[str, Database]):
9799
self.database_by_scheme = database_by_scheme
98100
self.match_uri_path = {
@@ -172,9 +174,11 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
172174
kw = {k: v for k, v in kw.items() if v is not None}
173175

174176
if issubclass(cls, ThreadedDatabase):
175-
return cls(thread_count=thread_count, **kw)
177+
db = cls(thread_count=thread_count, **kw)
178+
else:
179+
db = cls(**kw)
176180

177-
return cls(**kw)
181+
return self._connection_created(db)
178182

179183
def connect_with_dict(self, d, thread_count):
180184
d = dict(d)
@@ -186,9 +190,15 @@ def connect_with_dict(self, d, thread_count):
186190

187191
cls = matcher.database_cls
188192
if issubclass(cls, ThreadedDatabase):
189-
return cls(thread_count=thread_count, **d)
193+
db = cls(thread_count=thread_count, **d)
194+
else:
195+
db = cls(**d)
196+
197+
return self._connection_created(db)
190198

191-
return cls(**d)
199+
def _connection_created(self, db):
200+
"Nop function to be overridden by subclasses."
201+
return db
192202

193203
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
194204
"""Connect to a database using the given database configuration.

data_diff/sqeleton/databases/databricks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
8282
# Subtracting 2 due to wierd precision issues
8383
return max(super()._convert_db_precision_to_digits(p) - 2, 0)
8484

85+
def set_timezone_to_utc(self) -> str:
86+
return "SET TIME ZONE 'UTC'"
87+
8588

8689
class Databricks(ThreadedDatabase):
8790
dialect = Dialect()

data_diff/sqeleton/databases/duckdb.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def parse_type(
108108

109109
return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)
110110

111+
def set_timezone_to_utc(self) -> str:
112+
return "SET GLOBAL TimeZone='UTC'"
113+
111114

112115
class DuckDB(Database):
113116
dialect = Dialect()

data_diff/sqeleton/databases/mysql.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def type_repr(self, t) -> str:
9494
def explain_as_text(self, query: str) -> str:
9595
return f"EXPLAIN FORMAT=TREE {query}"
9696

97+
def set_timezone_to_utc(self) -> str:
98+
return "SET @@session.time_zone='+00:00'"
99+
97100

98101
class MySQL(ThreadedDatabase):
99102
dialect = Dialect()

0 commit comments

Comments
 (0)