Skip to content

Commit 2b9ca0c

Browse files
john-bodleyhashhar
authored andcommitted
Fix incorrect time-zone in results after localization
Consider the following example using the canonical way to add zones to datetime objects: >>> import pytz >>> import datetime >>> import zoneinfo >>> datetime.datetime(2023, 1, 1, tzinfo=pytz.timezone("America/Los_Angeles")).isoformat() '2023-01-01T00:00:00-07:53' >>> datetime.datetime(2023, 1, 1, tzinfo=zoneinfo.ZoneInfo("America/Los_Angeles")).isoformat() '2023-01-01T00:00:00-08:00' pytz does eager timezone evaluation and uses the local-mean-time since the instant in time is not known. It requires an additional `localize` call to get the correct zone like so: >>> pytz.timezone("America/Los_Angeles").localize(datetime.datetime(2023, 1, 1)).isoformat() '2023-01-01T00:00:00-08:00' This increases chances of introducing bugs when writing idiomatic Python. The only reason to use pytz was because it allowed to control what happens with ambiguous datetimes but the standard library also allows provides control over that since 3.9 (and is available as backports.zoneinfo for older versions).
1 parent a031cc2 commit 2b9ca0c

File tree

5 files changed

+43
-40
lines changed

5 files changed

+43
-40
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
python_requires='>=3.7',
8181
install_requires=[
8282
"backports.zoneinfo;python_version<'3.9'",
83+
"python-dateutil",
8384
"pytz",
8485
"requests",
8586
"tzlocal",

tests/integration/test_dbapi_integration.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
from decimal import Decimal
1616
from typing import Tuple
1717

18+
try:
19+
from zoneinfo import ZoneInfo
20+
except ModuleNotFoundError:
21+
from backports.zoneinfo import ZoneInfo
22+
1823
import pytest
19-
import pytz
2024
import requests
2125
from tzlocal import get_localzone_name # type: ignore
2226

@@ -234,7 +238,7 @@ def test_legacy_primitive_types_with_connection_and_cursor(
234238
assert rows[0][0] == Decimal('0.142857')
235239
assert rows[0][1] == date(2018, 1, 1)
236240
assert rows[0][2] == datetime(2019, 1, 1, tzinfo=timezone(timedelta(hours=1)))
237-
assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone('UTC'))
241+
assert rows[0][3] == datetime(2019, 1, 1, tzinfo=ZoneInfo('UTC'))
238242
assert rows[0][4] == datetime(2019, 1, 1)
239243
assert rows[0][5] == time(0, 0, 0, 0)
240244
else:
@@ -338,7 +342,7 @@ def test_datetime_query_param(trino_connection):
338342
def test_datetime_with_utc_time_zone_query_param(trino_connection):
339343
cur = trino_connection.cursor()
340344

341-
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('UTC'))
345+
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('UTC'))
342346

343347
cur.execute("SELECT ?", params=(params,))
344348
rows = cur.fetchall()
@@ -364,7 +368,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
364368
def test_datetime_with_named_time_zone_query_param(trino_connection):
365369
cur = trino_connection.cursor()
366370

367-
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))
371+
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('America/Los_Angeles'))
368372

369373
cur.execute("SELECT ?", params=(params,))
370374
rows = cur.fetchall()
@@ -407,32 +411,24 @@ def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection):
407411
cur = trino_connection.cursor()
408412

409413
# This is a datetime that lies within a DST transition and not actually exists.
410-
params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone('Europe/Brussels'))
414+
params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=ZoneInfo('Europe/Brussels'))
411415
with pytest.raises(trino.exceptions.TrinoUserError):
412416
cur.execute("SELECT ?", params=(params,))
413417
cur.fetchall()
414418

415419

416-
def test_doubled_datetimes(trino_connection):
417-
# Trino doesn't distinguish between doubled datetimes that lie within a DST transition. See also
420+
@pytest.mark.parametrize('fold', [0, 1])
421+
def test_doubled_datetimes(trino_connection, fold):
422+
# Trino doesn't distinguish between doubled datetimes that lie within a DST transition.
418423
# See also https://github.com/trinodb/trino/issues/5781
419424
cur = trino_connection.cursor()
420425

421-
params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True)
426+
params = datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern'), fold=fold)
422427

423428
cur.execute("SELECT ?", params=(params,))
424429
rows = cur.fetchall()
425430

426-
assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
427-
428-
cur = trino_connection.cursor()
429-
430-
params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False)
431-
432-
cur.execute("SELECT ?", params=(params,))
433-
rows = cur.fetchall()
434-
435-
assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
431+
assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern'))
436432

437433

438434
def test_date_query_param(trino_connection):
@@ -529,7 +525,7 @@ def test_time_query_param(trino_connection):
529525
def test_time_with_named_time_zone_query_param(trino_connection):
530526
cur = trino_connection.cursor()
531527

532-
params = time(16, 43, 22, 320000, tzinfo=pytz.timezone('Asia/Shanghai'))
528+
params = time(16, 43, 22, 320000, tzinfo=ZoneInfo('Asia/Shanghai'))
533529

534530
cur.execute("SELECT ?", params=(params,))
535531
rows = cur.fetchall()
@@ -693,7 +689,10 @@ def test_array_timestamp_query_param(trino_connection):
693689
def test_array_timestamp_with_timezone_query_param(trino_connection):
694690
cur = trino_connection.cursor()
695691

696-
params = [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)]
692+
params = [
693+
datetime(2020, 1, 1, 0, 0, 0, tzinfo=ZoneInfo('UTC')),
694+
datetime(2020, 1, 2, 0, 0, 0, tzinfo=ZoneInfo('UTC')),
695+
]
697696

698697
cur.execute("SELECT ?", params=(params,))
699698
rows = cur.fetchall()

tests/integration/test_types_integration.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
from datetime import date, datetime, time, timedelta, timezone, tzinfo
55
from decimal import Decimal
66

7+
try:
8+
from zoneinfo import ZoneInfo
9+
except ModuleNotFoundError:
10+
from backports.zoneinfo import ZoneInfo
11+
712
import pytest
8-
import pytz
913

1014
import trino
1115
from tests.integration.conftest import trino_version
@@ -729,7 +733,7 @@ def create_timezone(timezone_str: str) -> tzinfo:
729733
else:
730734
return timezone(-timedelta(hours=hours, minutes=minutes))
731735
else:
732-
return pytz.timezone(timezone_str)
736+
return ZoneInfo(timezone_str)
733737

734738

735739
def test_interval(trino_connection):

trino/client.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,18 @@
5151
from time import sleep
5252
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
5353

54-
import pytz
54+
try:
55+
from zoneinfo import ZoneInfo
56+
except ModuleNotFoundError:
57+
from backports.zoneinfo import ZoneInfo
58+
5559
import requests
56-
from pytz.tzinfo import BaseTzInfo
60+
from dateutil import tz
5761
from tzlocal import get_localzone_name # type: ignore
5862

5963
import trino.logging
6064
from trino import constants, exceptions
6165

62-
try:
63-
from zoneinfo import ZoneInfo # type: ignore
64-
65-
except ModuleNotFoundError:
66-
from backports.zoneinfo import ZoneInfo # type: ignore
67-
68-
6966
__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"]
7067

7168
logger = trino.logging.get_logger(__name__)
@@ -946,7 +943,7 @@ def _create_tzinfo(timezone_str: str) -> tzinfo:
946943
return timezone(-timedelta(hours=int(hours), minutes=int(minutes)))
947944
return timezone(timedelta(hours=int(hours), minutes=int(minutes)))
948945
else:
949-
return pytz.timezone(timezone_str)
946+
return ZoneInfo(timezone_str)
950947

951948

952949
def _fraction_to_decimal(fractional_str: str) -> Decimal:
@@ -996,8 +993,7 @@ def add_time_delta(self, time_delta: timedelta) -> PythonTemporalType:
996993
def normalize(self, value: PythonTemporalType) -> PythonTemporalType:
997994
"""
998995
If `add_time_delta` results in value crossing DST boundaries, this method should
999-
return a normalized version of the value to account for it, for example,
1000-
using `pytz.timezone.normalize`.
996+
return a normalized version of the value to account for it.
1001997
"""
1002998
return value
1003999

@@ -1041,7 +1037,7 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZ
10411037
return TimestampWithTimeZone(value, fraction)
10421038

10431039
def normalize(self, value: datetime) -> datetime:
1044-
if isinstance(self._whole_python_temporal_value.tzinfo, BaseTzInfo):
1040+
if tz.datetime_ambiguous(value):
10451041
return self._whole_python_temporal_value.tzinfo.normalize(value)
10461042
return value
10471043

trino/dbapi.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
2626
from urllib.parse import urlparse
2727

28-
import pytz
28+
try:
29+
from zoneinfo import ZoneInfo
30+
except ModuleNotFoundError:
31+
from backports.zoneinfo import ZoneInfo
2932

3033
import trino.client
3134
import trino.exceptions
@@ -425,8 +428,8 @@ def _format_prepared_param(self, param):
425428
if isinstance(param, datetime.datetime) and param.tzinfo is not None:
426429
datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f")
427430
# named timezones
428-
if hasattr(param.tzinfo, 'zone'):
429-
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.zone)
431+
if isinstance(param.tzinfo, ZoneInfo):
432+
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.key)
430433
# offset-based timezones
431434
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param))
432435

@@ -438,8 +441,8 @@ def _format_prepared_param(self, param):
438441
if isinstance(param, datetime.time) and param.tzinfo is not None:
439442
time_str = param.strftime("%H:%M:%S.%f")
440443
# named timezones
441-
if hasattr(param.tzinfo, 'zone'):
442-
utc_offset = datetime.datetime.now(pytz.timezone(param.tzinfo.zone)).strftime('%z')
444+
if isinstance(param.tzinfo, ZoneInfo):
445+
utc_offset = datetime.datetime.now(tz=param.tzinfo).strftime('%z')
443446
return "TIME '%s %s:%s'" % (time_str, utc_offset[:3], utc_offset[3:])
444447
# offset-based timezones
445448
return "TIME '%s %s'" % (time_str, param.strftime('%Z')[3:])

0 commit comments

Comments
 (0)