Skip to content

Commit 12e6f44

Browse files
mdesmethashhar
authored andcommitted
Add lazy evaluation of server_version_info
1 parent 2b9ca0c commit 12e6f44

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

tests/integration/test_sqlalchemy_integration.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import sqlalchemy as sqla
1616
from sqlalchemy.sql import and_, not_, or_
1717

18+
from tests.integration.conftest import trino_version
1819
from tests.unit.conftest import sqlalchemy_version
1920
from trino.sqlalchemy.datatype import JSON
2021

@@ -497,3 +498,24 @@ def test_get_view_names_raises(trino_connection):
497498

498499
with pytest.raises(sqla.exc.NoSuchTableError):
499500
sqla.inspect(engine).get_view_names(None)
501+
502+
503+
@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
504+
@pytest.mark.skipif(trino_version() == '351', reason="version() not supported in older Trino versions")
505+
def test_version_is_lazy(trino_connection):
506+
_, conn = trino_connection
507+
result = conn.execute(sqla.text("SELECT 1"))
508+
result.fetchall()
509+
num_queries = _num_queries_containing_string(conn, "SELECT version()")
510+
assert num_queries == 0
511+
version_info = conn.dialect.server_version_info
512+
assert isinstance(version_info, tuple)
513+
num_queries = _num_queries_containing_string(conn, "SELECT version()")
514+
assert num_queries == 1
515+
516+
517+
def _num_queries_containing_string(connection, query_string):
518+
statement = sqla.text("select query from system.runtime.queries order by query_id desc offset 1 limit 1")
519+
result = connection.execute(statement)
520+
rows = result.fetchall()
521+
return len(list(filter(lambda rec: query_string in rec[0], rows)))

trino/sqlalchemy/dialect.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,20 @@ def has_sequence(self, connection: Connection, sequence_name: str, schema: str =
336336
"""Trino has no support for sequence. Returns False indicate that given sequence does not exists."""
337337
return False
338338

339-
def _get_server_version_info(self, connection: Connection) -> Any:
340-
query = "SELECT version()"
341-
try:
342-
res = connection.execute(sql.text(query))
343-
version = res.scalar()
344-
return tuple([version])
345-
except exc.ProgrammingError as e:
346-
logger.debug(f"Failed to get server version: {e.orig.message}")
347-
return None
339+
@classmethod
340+
def _get_server_version_info(cls, connection: Connection) -> Any:
341+
def get_server_version_info(_):
342+
query = "SELECT version()"
343+
try:
344+
res = connection.execute(sql.text(query))
345+
version = res.scalar()
346+
return tuple([version])
347+
except exc.ProgrammingError as e:
348+
logger.debug(f"Failed to get server version: {e.orig.message}")
349+
return None
350+
351+
# Make server_version_info lazy in order to only make HTTP calls if user explicitly requests it.
352+
cls.server_version_info = property(get_server_version_info, lambda instance, value: None)
348353

349354
def _raw_connection(self, connection: Union[Engine, Connection]) -> trino_dbapi.Connection:
350355
if isinstance(connection, Engine):

0 commit comments

Comments
 (0)