Skip to content

Commit 334bc33

Browse files
sfc-gh-bchinnsfc-gh-pczajka
authored andcommitted
SNOW-2333702 Fix types for DictCursor (#2532)
Co-authored-by: Patryk Czajka <patryk.czajka@snowflake.com>
1 parent 6b77055 commit 334bc33

File tree

8 files changed

+220
-77
lines changed

8 files changed

+220
-77
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
import os
88
import pathlib
99
import sys
10+
import typing
1011
import uuid
1112
import warnings
1213
from contextlib import suppress
1314
from io import StringIO
1415
from logging import getLogger
1516
from types import TracebackType
16-
from typing import Any, AsyncIterator, Iterable
17+
from typing import Any, AsyncIterator, Iterable, TypeVar
1718

1819
from snowflake.connector import (
1920
DatabaseError,
@@ -72,7 +73,7 @@
7273
from ..time_util import get_time_millis
7374
from ..util_text import split_statements
7475
from ..wif_util import AttestationProvider
75-
from ._cursor import SnowflakeCursor
76+
from ._cursor import SnowflakeCursor, SnowflakeCursorBase
7677
from ._description import CLIENT_NAME
7778
from ._direct_file_operation_utils import FileOperationParser, StreamDownloader
7879
from ._network import SnowflakeRestful
@@ -107,6 +108,11 @@
107108
DEFAULT_CONFIGURATION = copy.deepcopy(DEFAULT_CONFIGURATION_SYNC)
108109
DEFAULT_CONFIGURATION["application"] = (CLIENT_NAME, (type(None), str))
109110

111+
if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
112+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
113+
else:
114+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)
115+
110116

111117
class SnowflakeConnection(SnowflakeConnectionSync):
112118
OCSP_ENV_LOCK = asyncio.Lock()
@@ -1032,9 +1038,7 @@ async def connect(self, **kwargs) -> None:
10321038
self._telemetry = TelemetryClient(self._rest)
10331039
await self._log_telemetry_imported_packages()
10341040

1035-
def cursor(
1036-
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
1037-
) -> SnowflakeCursor:
1041+
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
10381042
logger.debug("cursor")
10391043
if not self.rest:
10401044
Error.errorhandler_wrapper(

src/snowflake/connector/aio/_cursor.py

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import abc
34
import asyncio
45
import collections
56
import logging
@@ -39,10 +40,11 @@
3940
ASYNC_NO_DATA_MAX_RETRY,
4041
ASYNC_RETRY_PATTERN,
4142
DESC_TABLE_RE,
43+
ResultMetadata,
44+
ResultMetadataV2,
45+
ResultState,
4246
)
43-
from snowflake.connector.cursor import DictCursor as DictCursorSync
44-
from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState
45-
from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync
47+
from snowflake.connector.cursor import SnowflakeCursorBase as SnowflakeCursorBaseSync
4648
from snowflake.connector.cursor import T
4749
from snowflake.connector.errorcode import (
4850
ER_CURSOR_IS_CLOSED,
@@ -66,17 +68,20 @@
6668

6769
logger = getLogger(__name__)
6870

71+
FetchRow = typing.TypeVar(
72+
"FetchRow", bound=typing.Union[typing.Tuple[Any, ...], typing.Dict[str, Any]]
73+
)
74+
6975

70-
class SnowflakeCursor(SnowflakeCursorSync):
76+
class SnowflakeCursorBase(SnowflakeCursorBaseSync, abc.ABC, typing.Generic[FetchRow]):
7177
def __init__(
7278
self,
7379
connection: SnowflakeConnection,
74-
use_dict_result: bool = False,
7580
):
76-
super().__init__(connection, use_dict_result)
81+
super().__init__(connection)
7782
# the following fixes type hint
7883
self._connection = typing.cast("SnowflakeConnection", self._connection)
79-
self._inner_cursor: SnowflakeCursor | None = None
84+
self._inner_cursor: SnowflakeCursorBase | None = None
8085
self._lock_canceling = asyncio.Lock()
8186
self._timebomb: asyncio.Task | None = None
8287
self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None
@@ -894,8 +899,17 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
894899
return None
895900
return [meta._to_result_metadata_v1() for meta in self._description]
896901

897-
async def fetchone(self) -> dict | tuple | None:
898-
"""Fetches one row."""
902+
@abc.abstractmethod
903+
async def fetchone(self) -> FetchRow:
904+
pass
905+
906+
async def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
907+
"""
908+
Fetches one row.
909+
910+
Returns a dict if self._use_dict_result is True, otherwise
911+
returns tuple.
912+
"""
899913
if self._prefetch_hook is not None:
900914
await self._prefetch_hook()
901915
if self._result is None and self._result_set is not None:
@@ -920,7 +934,7 @@ async def fetchone(self) -> dict | tuple | None:
920934
else:
921935
return None
922936

923-
async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
937+
async def fetchmany(self, size: int | None = None) -> list[FetchRow]:
924938
"""Fetches the number of specified rows."""
925939
if size is None:
926940
size = self.arraysize
@@ -1248,20 +1262,31 @@ async def wait_until_ready() -> None:
12481262
# Unset this function, so that we don't block anymore
12491263
self._prefetch_hook = None
12501264

1251-
if (
1252-
self._inner_cursor._total_rowcount == 1
1253-
and await self._inner_cursor.fetchall()
1254-
== [("Multiple statements executed successfully.",)]
1265+
if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt(
1266+
await self._inner_cursor.fetchall()
12551267
):
12561268
url = f"/queries/{sfqid}/result"
12571269
ret = await self._connection.rest.request(url=url, method="get")
12581270
if "data" in ret and "resultIds" in ret["data"]:
12591271
await self._init_multi_statement_results(ret["data"])
12601272

1273+
def _is_successful_multi_stmt(rows: list[Any]) -> bool:
1274+
if len(rows) != 1:
1275+
return False
1276+
row = rows[0]
1277+
if isinstance(row, tuple):
1278+
return row == ("Multiple statements executed successfully.",)
1279+
elif isinstance(row, dict):
1280+
return row == {
1281+
"multiple statement execution": "Multiple statements executed successfully."
1282+
}
1283+
else:
1284+
return False
1285+
12611286
await self.connection.get_query_status_throw_if_error(
12621287
sfqid
12631288
) # Trigger an exception if query failed
1264-
self._inner_cursor = SnowflakeCursor(self.connection)
1289+
self._inner_cursor = self.__class__(self.connection)
12651290
self._sfqid = sfqid
12661291
self._prefetch_hook = wait_until_ready
12671292

@@ -1324,5 +1349,50 @@ def _create_file_transfer_agent(
13241349
)
13251350

13261351

1327-
class DictCursor(DictCursorSync, SnowflakeCursor):
1328-
pass
1352+
class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]):
1353+
"""Implementation of Cursor object that is returned from Connection.cursor() method.
1354+
1355+
Attributes:
1356+
description: A list of namedtuples about metadata for all columns.
1357+
rowcount: The number of records updated or selected. If not clear, -1 is returned.
1358+
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
1359+
determined.
1360+
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
1361+
sqlstate: Snowflake SQL State code.
1362+
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
1363+
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
1364+
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
1365+
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
1366+
date_output_format: Snowflake output format for dates.
1367+
time_output_format: Snowflake output format for times.
1368+
timezone: Snowflake timezone.
1369+
binary_output_format: Snowflake output format for binary fields.
1370+
arraysize: The default number of rows fetched by fetchmany.
1371+
connection: The connection object by which the cursor was created.
1372+
errorhandle: The class that handles error handling.
1373+
is_file_transfer: Whether, or not the current command is a put, or get.
1374+
"""
1375+
1376+
@property
1377+
def _use_dict_result(self) -> bool:
1378+
return False
1379+
1380+
async def fetchone(self) -> tuple[Any, ...] | None:
1381+
row = await self._fetchone()
1382+
if not (row is None or isinstance(row, tuple)):
1383+
raise TypeError(f"fetchone got unexpected result: {row}")
1384+
return row
1385+
1386+
1387+
class DictCursor(SnowflakeCursorBase[dict[str, Any]]):
1388+
"""Cursor returning results in a dictionary."""
1389+
1390+
@property
1391+
def _use_dict_result(self) -> bool:
1392+
return True
1393+
1394+
async def fetchone(self) -> dict[str, Any] | None:
1395+
row = await self._fetchone()
1396+
if not (row is None or isinstance(row, dict)):
1397+
raise TypeError(f"fetchone got unexpected result: {row}")
1398+
return row

src/snowflake/connector/connection.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import sys
1010
import traceback
11+
import typing
1112
import uuid
1213
import warnings
1314
import weakref
@@ -20,7 +21,16 @@
2021
from logging import getLogger
2122
from threading import Lock
2223
from types import TracebackType
23-
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
24+
from typing import (
25+
Any,
26+
Callable,
27+
Generator,
28+
Iterable,
29+
Iterator,
30+
NamedTuple,
31+
Sequence,
32+
TypeVar,
33+
)
2434
from uuid import UUID
2535

2636
from cryptography.hazmat.backends import default_backend
@@ -76,7 +86,7 @@
7686
QueryStatus,
7787
)
7888
from .converter import SnowflakeConverter
79-
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor
89+
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
8090
from .description import (
8191
CLIENT_NAME,
8292
CLIENT_VERSION,
@@ -125,6 +135,11 @@
125135
from .util_text import construct_hostname, parse_account, split_statements
126136
from .wif_util import AttestationProvider
127137

138+
if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
139+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
140+
else:
141+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)
142+
128143
DEFAULT_CLIENT_PREFETCH_THREADS = 4
129144
MAX_CLIENT_PREFETCH_THREADS = 10
130145
MAX_CLIENT_FETCH_THREADS = 1024
@@ -1059,9 +1074,7 @@ def rollback(self) -> None:
10591074
"""Rolls back the current transaction."""
10601075
self.cursor().execute("ROLLBACK")
10611076

1062-
def cursor(
1063-
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
1064-
) -> SnowflakeCursor:
1077+
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
10651078
"""Creates a cursor object. Each statement will be executed in a new cursor object."""
10661079
logger.debug("cursor")
10671080
if not self.rest:

0 commit comments

Comments
 (0)