11from __future__ import annotations
22
3+ import abc
34import asyncio
45import collections
56import logging
3940 ASYNC_NO_DATA_MAX_RETRY ,
4041 ASYNC_RETRY_PATTERN ,
4142 DESC_TABLE_RE ,
43+ ResultMetadata ,
44+ ResultMetadataV2 ,
45+ ResultState ,
4246)
43- from snowflake .connector .cursor import DictCursor as DictCursorSync
44- from snowflake .connector .cursor import ResultMetadata , ResultMetadataV2 , ResultState
45- from snowflake .connector .cursor import SnowflakeCursor as SnowflakeCursorSync
47+ from snowflake .connector .cursor import SnowflakeCursorBase as SnowflakeCursorBaseSync
4648from snowflake .connector .cursor import T
4749from snowflake .connector .errorcode import (
4850 ER_CURSOR_IS_CLOSED ,
6668
6769logger = getLogger (__name__ )
6870
71+ FetchRow = typing .TypeVar (
72+ "FetchRow" , bound = typing .Union [typing .Tuple [Any , ...], typing .Dict [str , Any ]]
73+ )
74+
6975
70- class SnowflakeCursor ( SnowflakeCursorSync ):
76+ class SnowflakeCursorBase ( SnowflakeCursorBaseSync , abc . ABC , typing . Generic [ FetchRow ] ):
7177 def __init__ (
7278 self ,
7379 connection : SnowflakeConnection ,
74- use_dict_result : bool = False ,
7580 ):
76- super ().__init__ (connection , use_dict_result )
81+ super ().__init__ (connection )
7782 # the following fixes type hint
7883 self ._connection = typing .cast ("SnowflakeConnection" , self ._connection )
79- self ._inner_cursor : SnowflakeCursor | None = None
84+ self ._inner_cursor : SnowflakeCursorBase | None = None
8085 self ._lock_canceling = asyncio .Lock ()
8186 self ._timebomb : asyncio .Task | None = None
8287 self ._prefetch_hook : typing .Callable [[], typing .Awaitable ] | None = None
@@ -894,8 +899,17 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
894899 return None
895900 return [meta ._to_result_metadata_v1 () for meta in self ._description ]
896901
897- async def fetchone (self ) -> dict | tuple | None :
898- """Fetches one row."""
902+ @abc .abstractmethod
903+ async def fetchone (self ) -> FetchRow :
904+ pass
905+
906+ async def _fetchone (self ) -> dict [str , Any ] | tuple [Any , ...] | None :
907+ """
908+ Fetches one row.
909+
910+ Returns a dict if self._use_dict_result is True, otherwise
911+ returns tuple.
912+ """
899913 if self ._prefetch_hook is not None :
900914 await self ._prefetch_hook ()
901915 if self ._result is None and self ._result_set is not None :
@@ -920,7 +934,7 @@ async def fetchone(self) -> dict | tuple | None:
920934 else :
921935 return None
922936
923- async def fetchmany (self , size : int | None = None ) -> list [tuple ] | list [ dict ]:
937+ async def fetchmany (self , size : int | None = None ) -> list [FetchRow ]:
924938 """Fetches the number of specified rows."""
925939 if size is None :
926940 size = self .arraysize
@@ -1248,20 +1262,31 @@ async def wait_until_ready() -> None:
12481262 # Unset this function, so that we don't block anymore
12491263 self ._prefetch_hook = None
12501264
1251- if (
1252- self ._inner_cursor ._total_rowcount == 1
1253- and await self ._inner_cursor .fetchall ()
1254- == [("Multiple statements executed successfully." ,)]
1265+ if self ._inner_cursor ._total_rowcount == 1 and _is_successful_multi_stmt (
1266+ await self ._inner_cursor .fetchall ()
12551267 ):
12561268 url = f"/queries/{ sfqid } /result"
12571269 ret = await self ._connection .rest .request (url = url , method = "get" )
12581270 if "data" in ret and "resultIds" in ret ["data" ]:
12591271 await self ._init_multi_statement_results (ret ["data" ])
12601272
1273+ def _is_successful_multi_stmt (rows : list [Any ]) -> bool :
1274+ if len (rows ) != 1 :
1275+ return False
1276+ row = rows [0 ]
1277+ if isinstance (row , tuple ):
1278+ return row == ("Multiple statements executed successfully." ,)
1279+ elif isinstance (row , dict ):
1280+ return row == {
1281+ "multiple statement execution" : "Multiple statements executed successfully."
1282+ }
1283+ else :
1284+ return False
1285+
12611286 await self .connection .get_query_status_throw_if_error (
12621287 sfqid
12631288 ) # Trigger an exception if query failed
1264- self ._inner_cursor = SnowflakeCursor (self .connection )
1289+ self ._inner_cursor = self . __class__ (self .connection )
12651290 self ._sfqid = sfqid
12661291 self ._prefetch_hook = wait_until_ready
12671292
@@ -1324,5 +1349,50 @@ def _create_file_transfer_agent(
13241349 )
13251350
13261351
1327- class DictCursor (DictCursorSync , SnowflakeCursor ):
1328- pass
1352+ class SnowflakeCursor (SnowflakeCursorBase [tuple [Any , ...]]):
1353+ """Implementation of Cursor object that is returned from Connection.cursor() method.
1354+
1355+ Attributes:
1356+ description: A list of namedtuples about metadata for all columns.
1357+ rowcount: The number of records updated or selected. If not clear, -1 is returned.
1358+ rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
1359+ determined.
1360+ sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
1361+ sqlstate: Snowflake SQL State code.
1362+ timestamp_output_format: Snowflake timestamp_output_format for timestamps.
1363+ timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
1364+ timestamp_tz_output_format: Snowflake output format for TZ timestamps.
1365+ timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
1366+ date_output_format: Snowflake output format for dates.
1367+ time_output_format: Snowflake output format for times.
1368+ timezone: Snowflake timezone.
1369+ binary_output_format: Snowflake output format for binary fields.
1370+ arraysize: The default number of rows fetched by fetchmany.
1371+ connection: The connection object by which the cursor was created.
1372+ errorhandle: The class that handles error handling.
1373+ is_file_transfer: Whether, or not the current command is a put, or get.
1374+ """
1375+
1376+ @property
1377+ def _use_dict_result (self ) -> bool :
1378+ return False
1379+
1380+ async def fetchone (self ) -> tuple [Any , ...] | None :
1381+ row = await self ._fetchone ()
1382+ if not (row is None or isinstance (row , tuple )):
1383+ raise TypeError (f"fetchone got unexpected result: { row } " )
1384+ return row
1385+
1386+
1387+ class DictCursor (SnowflakeCursorBase [dict [str , Any ]]):
1388+ """Cursor returning results in a dictionary."""
1389+
1390+ @property
1391+ def _use_dict_result (self ) -> bool :
1392+ return True
1393+
1394+ async def fetchone (self ) -> dict [str , Any ] | None :
1395+ row = await self ._fetchone ()
1396+ if not (row is None or isinstance (row , dict )):
1397+ raise TypeError (f"fetchone got unexpected result: { row } " )
1398+ return row
0 commit comments