Skip to content

Commit 226b216

Browse files
lpoulainhashhar
authored andcommitted
Support UUID for SQLAlchemy
Support UUID for SQLAlchemy
1 parent c1a8f1e commit 226b216

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

tests/integration/test_sqlalchemy_integration.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License
12+
import uuid
13+
1214
import pytest
1315
import sqlalchemy as sqla
1416
from sqlalchemy.sql import and_, not_, or_
@@ -133,6 +135,60 @@ def test_insert(trino_connection):
133135
metadata.drop_all(engine)
134136

135137

138+
@pytest.mark.skipif(
139+
sqlalchemy_version() < "2.0",
140+
reason="sqlalchemy.Uuid only exists with SQLAlchemy 2.0 and above"
141+
)
142+
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
143+
def test_define_and_create_table_uuid(trino_connection):
144+
engine, conn = trino_connection
145+
if not engine.dialect.has_schema(conn, "test"):
146+
with engine.begin() as connection:
147+
connection.execute(sqla.schema.CreateSchema("test"))
148+
metadata = sqla.MetaData()
149+
try:
150+
sqla.Table('users',
151+
metadata,
152+
sqla.Column('guid', sqla.Uuid),
153+
schema="test")
154+
metadata.create_all(engine)
155+
assert sqla.inspect(engine).has_table('users', schema="test")
156+
users = sqla.Table('users', metadata, schema='test', autoload_with=conn)
157+
assert_column(users, "guid", sqla.sql.sqltypes.Uuid)
158+
finally:
159+
metadata.drop_all(engine)
160+
161+
162+
@pytest.mark.skipif(
163+
sqlalchemy_version() < "2.0",
164+
reason="sqlalchemy.Uuid only exists with SQLAlchemy 2.0 and above"
165+
)
166+
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
167+
def test_insert_uuid(trino_connection):
168+
engine, conn = trino_connection
169+
170+
if not engine.dialect.has_schema(conn, "test"):
171+
with engine.begin() as connection:
172+
connection.execute(sqla.schema.CreateSchema("test"))
173+
metadata = sqla.MetaData()
174+
try:
175+
users = sqla.Table('users',
176+
metadata,
177+
sqla.Column('guid', sqla.Uuid),
178+
schema="test")
179+
metadata.create_all(engine)
180+
ins = users.insert()
181+
guid = uuid.uuid4()
182+
conn.execute(ins, {"guid": guid})
183+
query = sqla.select(users)
184+
result = conn.execute(query)
185+
rows = result.fetchall()
186+
assert len(rows) == 1
187+
assert rows[0] == (guid,)
188+
finally:
189+
metadata.drop_all(engine)
190+
191+
136192
@pytest.mark.skipif(
137193
sqlalchemy_version() < "1.4",
138194
reason="columns argument to select() must be a Python list or other iterable"

trino/sqlalchemy/datatype.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import re
1414
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
1515

16+
import sqlalchemy
1617
from sqlalchemy import util
1718
from sqlalchemy.sql import sqltypes
1819
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
@@ -129,6 +130,9 @@ def get_col_spec(self, **kw):
129130
# 'tdigest': TDIGEST,
130131
}
131132

133+
if hasattr(sqlalchemy, "Uuid"):
134+
_type_map["uuid"] = sqlalchemy.Uuid
135+
132136

133137
def unquote(string: str, quote: str = '"', escape: str = "\\") -> str:
134138
"""

0 commit comments

Comments
 (0)