@@ -166,8 +166,34 @@ def test_execute_many_select(trino_connection):
166166 assert "Query must return update type" in str (e .value )
167167
168168
169- def test_python_types_not_used_when_experimental_python_types_is_not_set (trino_connection ):
170- cur = trino_connection .cursor ()
169+ @pytest .mark .parametrize ("connection_experimental_python_types,cursor_experimental_python_types,expected" ,
170+ [
171+ (None , None , False ),
172+ (None , False , False ),
173+ (None , True , True ),
174+ (False , None , False ),
175+ (False , False , False ),
176+ (False , True , True ),
177+ (True , None , True ),
178+ (True , False , False ),
179+ (True , True , True ),
180+ ])
181+ def test_experimental_python_types_with_connection_and_cursor (
182+ connection_experimental_python_types ,
183+ cursor_experimental_python_types ,
184+ expected ,
185+ run_trino
186+ ):
187+ _ , host , port = run_trino
188+
189+ connection = trino .dbapi .Connection (
190+ host = host ,
191+ port = port ,
192+ user = "test" ,
193+ experimental_python_types = connection_experimental_python_types ,
194+ )
195+
196+ cur = connection .cursor (experimental_python_types = cursor_experimental_python_types )
171197
172198 cur .execute ("""
173199 SELECT
@@ -180,15 +206,23 @@ def test_python_types_not_used_when_experimental_python_types_is_not_set(trino_c
180206 """ )
181207 rows = cur .fetchall ()
182208
183- for value in rows [0 ]:
184- assert isinstance (value , str )
185-
186- assert rows [0 ][0 ] == '0.142857'
187- assert rows [0 ][1 ] == '2018-01-01'
188- assert rows [0 ][2 ] == '2019-01-01 00:00:00.000 +01:00'
189- assert rows [0 ][3 ] == '2019-01-01 00:00:00.000 UTC'
190- assert rows [0 ][4 ] == '2019-01-01 00:00:00.000'
191- assert rows [0 ][5 ] == '00:00:00.000'
209+ if expected :
210+ assert rows [0 ][0 ] == Decimal ('0.142857' )
211+ assert rows [0 ][1 ] == date (2018 , 1 , 1 )
212+ assert rows [0 ][2 ] == datetime (2019 , 1 , 1 , tzinfo = timezone (timedelta (hours = 1 )))
213+ assert rows [0 ][3 ] == datetime (2019 , 1 , 1 , tzinfo = pytz .timezone ('UTC' ))
214+ assert rows [0 ][4 ] == datetime (2019 , 1 , 1 )
215+ assert rows [0 ][5 ] == time (0 , 0 , 0 , 0 )
216+ else :
217+ for value in rows [0 ]:
218+ assert isinstance (value , str )
219+
220+ assert rows [0 ][0 ] == '0.142857'
221+ assert rows [0 ][1 ] == '2018-01-01'
222+ assert rows [0 ][2 ] == '2019-01-01 00:00:00.000 +01:00'
223+ assert rows [0 ][3 ] == '2019-01-01 00:00:00.000 UTC'
224+ assert rows [0 ][4 ] == '2019-01-01 00:00:00.000'
225+ assert rows [0 ][5 ] == '00:00:00.000'
192226
193227
194228def test_decimal_query_param (trino_connection ):
0 commit comments