Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 01abf3a

Browse files
authored
Merge pull request #306 from datafold/issue284
data-diff now uses database A's now instead of cli's now.
2 parents 9289570 + 79a419c commit 01abf3a

26 files changed

+196
-104
lines changed

data_diff/__main__.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from copy import deepcopy
2+
from datetime import datetime
23
import sys
34
import time
45
import json
@@ -15,8 +16,9 @@
1516
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
1617
from .table_segment import TableSegment
1718
from .sqeleton.schema import create_schema
19+
from .sqeleton.queries.api import current_timestamp
1820
from .databases import connect
19-
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
21+
from .parse_time import parse_time_before, UNITS_STR, ParseError
2022
from .config import apply_config_from_file
2123
from .tracking import disable_tracking
2224
from . import __version__
@@ -299,17 +301,6 @@ def _main(
299301

300302
start = time.monotonic()
301303

302-
try:
303-
options = dict(
304-
min_update=max_age and parse_time_before_now(max_age),
305-
max_update=min_age and parse_time_before_now(min_age),
306-
case_sensitive=case_sensitive,
307-
where=where,
308-
)
309-
except ParseError as e:
310-
logging.error(f"Error while parsing age expression: {e}")
311-
return
312-
313304
if database1 is None or database2 is None:
314305
logging.error(
315306
f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information."
@@ -326,6 +317,20 @@ def _main(
326317
logging.error(e)
327318
return
328319

320+
321+
now: datetime = db1.query(current_timestamp(), datetime)
322+
now = now.replace(tzinfo=None)
323+
try:
324+
options = dict(
325+
min_update=max_age and parse_time_before(now, max_age),
326+
max_update=min_age and parse_time_before(now, min_age),
327+
case_sensitive=case_sensitive,
328+
where=where,
329+
)
330+
except ParseError as e:
331+
logging.error(f"Error while parsing age expression: {e}")
332+
return
333+
329334
dbs = db1, db2
330335

331336
if interactive:

data_diff/databases/_connect.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from data_diff.sqeleton.databases import Connect
1+
from data_diff.sqeleton.databases import Connect, Database
2+
import logging
23

34
from .postgresql import PostgreSQL
45
from .mysql import MySQL
@@ -29,4 +30,22 @@
2930
"vertica": Vertica,
3031
}
3132

32-
connect = Connect(DATABASE_BY_SCHEME)
33+
34+
class Connect_SetUTC(Connect):
35+
"""Provides methods for connecting to a supported database using a URL or connection dict.
36+
37+
Ensures all sessions use UTC Timezone, if possible.
38+
"""
39+
40+
def _connection_created(self, db):
41+
db = super()._connection_created(db)
42+
try:
43+
db.query(db.dialect.set_timezone_to_utc())
44+
except NotImplementedError:
45+
logging.debug(
46+
f"Database '{db}' does not allow setting timezone. We recommend making sure it's set to 'UTC'."
47+
)
48+
return db
49+
50+
51+
connect = Connect_SetUTC(DATABASE_BY_SCHEME)

data_diff/joindiff_tables.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def sample(table_expr):
6262

6363
def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
6464
db = c.database
65+
c = c.replace(root=False) # we're compiling fragments, not full queries
6566
if isinstance(db, BigQuery):
6667
return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}"
6768
elif isinstance(db, Presto):

data_diff/parse_time.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,5 @@ def parse_time_delta(t: str):
7070
return timedelta(**time_dict)
7171

7272

73-
def parse_time_before_now(t: str):
74-
return datetime.now() - parse_time_delta(t)
73+
def parse_time_before(time: datetime, delta: str):
74+
return time - parse_time_delta(delta)

data_diff/sqeleton/abcs/database_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ def to_string(self, s: str) -> str:
184184
def random(self) -> str:
185185
"Provide SQL for generating a random number betweein 0..1"
186186

187+
@abstractmethod
188+
def current_timestamp(self) -> str:
189+
"Provide SQL for returning the current timestamp, aka now"
190+
187191
@abstractmethod
188192
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
189193
"Provide SQL fragment for limit and offset inside a select"
@@ -199,6 +203,10 @@ def timestamp_value(self, t: datetime) -> str:
199203
"Provide SQL for the given timestamp value"
200204
...
201205

206+
@abstractmethod
207+
def set_timezone_to_utc(self) -> str:
208+
"Provide SQL for setting the session timezone to UTC"
209+
202210
@abstractmethod
203211
def parse_type(
204212
self,

data_diff/sqeleton/databases/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ def timestamp_value(self, t: DbTime) -> str:
142142
return f"'{t.isoformat()}'"
143143

144144
def random(self) -> str:
145-
return "RANDOM()"
145+
return "random()"
146+
147+
def current_timestamp(self) -> str:
148+
return "current_timestamp()"
146149

147150
def explain_as_text(self, query: str) -> str:
148151
return f"EXPLAIN {query}"

data_diff/sqeleton/databases/bigquery.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ def type_repr(self, t) -> str:
101101
except KeyError:
102102
return super().type_repr(t)
103103

104+
def set_timezone_to_utc(self) -> str:
105+
raise NotImplementedError()
106+
104107

105108
class BigQuery(Database):
106109
CONNECT_URI_HELP = "bigquery://<project>/<dataset>"

data_diff/sqeleton/databases/clickhouse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
150150
# # return f"'{t}'"
151151
# return f"'{str(t)[:19]}'"
152152

153+
def set_timezone_to_utc(self) -> str:
154+
raise NotImplementedError()
155+
153156

154157
class Clickhouse(ThreadedDatabase):
155158
dialect = Dialect()

data_diff/sqeleton/databases/connect.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def match_path(self, dsn):
9393

9494

9595
class Connect:
96+
"""Provides methods for connecting to a supported database using a URL or connection dict."""
97+
9698
def __init__(self, database_by_scheme: Dict[str, Database]):
9799
self.database_by_scheme = database_by_scheme
98100
self.match_uri_path = {
@@ -172,9 +174,11 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
172174
kw = {k: v for k, v in kw.items() if v is not None}
173175

174176
if issubclass(cls, ThreadedDatabase):
175-
return cls(thread_count=thread_count, **kw)
177+
db = cls(thread_count=thread_count, **kw)
178+
else:
179+
db = cls(**kw)
176180

177-
return cls(**kw)
181+
return self._connection_created(db)
178182

179183
def connect_with_dict(self, d, thread_count):
180184
d = dict(d)
@@ -186,9 +190,15 @@ def connect_with_dict(self, d, thread_count):
186190

187191
cls = matcher.database_cls
188192
if issubclass(cls, ThreadedDatabase):
189-
return cls(thread_count=thread_count, **d)
193+
db = cls(thread_count=thread_count, **d)
194+
else:
195+
db = cls(**d)
196+
197+
return self._connection_created(db)
190198

191-
return cls(**d)
199+
def _connection_created(self, db):
200+
"Nop function to be overridden by subclasses."
201+
return db
192202

193203
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
194204
"""Connect to a database using the given database configuration.

data_diff/sqeleton/databases/databricks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
8282
# Subtracting 2 due to wierd precision issues
8383
return max(super()._convert_db_precision_to_digits(p) - 2, 0)
8484

85+
def set_timezone_to_utc(self) -> str:
86+
return "SET TIME ZONE 'UTC'"
87+
8588

8689
class Databricks(ThreadedDatabase):
8790
dialect = Dialect()

0 commit comments

Comments
 (0)