Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
aa8d8be
[async] Designed and implemented async connection context manager wit…
sfc-gh-fpawlowski Oct 28, 2025
738f6fa
[async] Improved async connection context manager with wrapping.
sfc-gh-fpawlowski Oct 28, 2025
66ec753
[async] Full coroutine protocol supported
sfc-gh-fpawlowski Oct 28, 2025
2277a3d
[async] Fixed and tests
sfc-gh-fpawlowski Oct 28, 2025
1aaa618
[async] Tests
sfc-gh-fpawlowski Oct 28, 2025
013390c
[async] Approach 1 - Idempotent connection.__aenter__ through checkin…
sfc-gh-fpawlowski Oct 28, 2025
aadca49
[async] Added docs
sfc-gh-fpawlowski Oct 28, 2025
967d0be
NO-SNOW: cleaned up code
sfc-gh-fpawlowski Nov 2, 2025
560b381
NO-SNOW: doc update
sfc-gh-fpawlowski Nov 2, 2025
7d99b42
NO-SNOW: doc update
sfc-gh-fpawlowski Nov 2, 2025
c96d001
NO-SNOW: failing expected interface of connect
sfc-gh-fpawlowski Nov 4, 2025
8c833e2
NO-SNOW: passing interface of connect synch. Failing asynch connect t…
sfc-gh-fpawlowski Nov 4, 2025
4265b81
NO-SNOW: passing interface of connect async - but ai just made those …
sfc-gh-fpawlowski Nov 4, 2025
d395e36
NO-SNOW: passing interface of connect async - but ai just made those …
sfc-gh-fpawlowski Nov 4, 2025
eb5728d
NO-SNOW: passing interface of connect async ig
sfc-gh-fpawlowski Nov 4, 2025
44b3953
NO-SNOW: Tests for metadata preservation added
sfc-gh-fpawlowski Nov 4, 2025
5aea30a
NO-SNOW: Simplify callable instance into function
sfc-gh-fpawlowski Nov 4, 2025
ffb118d
NO-SNOW: Simplify callable instance into function
sfc-gh-fpawlowski Nov 4, 2025
9dec634
SNOW-2671717: Removed file doc
sfc-gh-fpawlowski Nov 4, 2025
0b24d0d
NO-SNOW: remove comment
sfc-gh-fpawlowski Nov 4, 2025
e1cccfb
NO-SNOW: Docs improved
sfc-gh-fpawlowski Nov 4, 2025
5234ea9
NO-SNOW: Removed synch changes
sfc-gh-fpawlowski Nov 4, 2025
ef0e4fd
NO-SNOW: CLean up
sfc-gh-fpawlowski Nov 4, 2025
48123e7
NO-SNOW: final api
sfc-gh-fpawlowski Nov 5, 2025
7b196c3
SNOW-1675422: fixed broken jobs - duplicated leftover
sfc-gh-fpawlowski Nov 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 151 additions & 4 deletions src/snowflake/connector/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from functools import wraps
from typing import Any, Coroutine, Generator, Protocol, TypeVar, runtime_checkable

from ._connection import SnowflakeConnection
from ._cursor import DictCursor, SnowflakeCursor

Expand All @@ -9,8 +12,152 @@
DictCursor,
]

# ============================================================================
# DESIGN NOTES:
#
# Pattern similar to aiohttp.ClientSession.request() which similarly returns
# an object that can be both awaited and used as an async context manager.
#
# The async connect function uses a wrapper to support both:
# 1. Direct awaiting: conn = await connect(...)
# 2. Async context manager: async with connect(...) as conn:
#
# connect: A function decorated with @wraps(SnowflakeConnection.__init__) that
# preserves metadata for IDE support, type checking, and introspection.
# Returns a _AsyncConnectContextManager instance when called.
#
# _AsyncConnectContextManager: Implements __await__ and __aenter__/__aexit__
# to support both patterns on the same awaitable.
#
# The @wraps decorator ensures that connect() has the same signature and
# documentation as SnowflakeConnection.__init__, making it behave identically
# to the sync snowflake.connector.connect function from an introspection POV.
#
# Metadata preservation is critical for IDE autocomplete, static type checkers,
# and documentation generation to work correctly on the async connect function.
# ============================================================================


T = TypeVar("T")


@runtime_checkable
class HybridCoroutineContextManager(Protocol[T]):
"""Protocol for a hybrid coroutine that is also an async context manager.

Combines the full coroutine protocol (PEP 492) with async context manager
protocol (PEP 343/492), allowing code that expects either interface to work
seamlessly with instances of this protocol.

This is used when external code needs to manage the coroutine lifecycle
(e.g., timeout handlers, async schedulers) or use it as a context manager.
"""

# Full Coroutine Protocol (PEP 492)
def send(self, __arg: Any) -> Any:
"""Send a value into the coroutine."""
...

def throw(
self,
__typ: type[BaseException],
__val: BaseException | None = None,
__tb: Any = None,
) -> Any:
"""Throw an exception into the coroutine."""
...

def close(self) -> None:
"""Close the coroutine."""
...

def __await__(self) -> Generator[Any, None, T]:
"""Return awaitable generator."""
...

def __iter__(self) -> Generator[Any, None, T]:
"""Iterate over the coroutine."""
...

# Async Context Manager Protocol (PEP 343)
async def __aenter__(self) -> T:
"""Async context manager entry."""
...

async def __aexit__(
self,
__exc_type: type[BaseException] | None,
__exc_val: BaseException | None,
__exc_tb: Any,
) -> bool | None:
"""Async context manager exit."""
...


class _AsyncConnectContextManager(HybridCoroutineContextManager[SnowflakeConnection]):
"""Hybrid wrapper that enables both awaiting and async context manager usage.

Allows both patterns:
- conn = await connect(...)
- async with connect(...) as conn:

Implements the full coroutine protocol for maximum compatibility.
Satisfies the HybridCoroutineContextManager protocol.
"""

__slots__ = ("_coro", "_conn")

def __init__(self, coro: Coroutine[Any, Any, SnowflakeConnection]) -> None:
self._coro = coro
self._conn: SnowflakeConnection | None = None

def send(self, arg: Any) -> Any:
"""Send a value into the wrapped coroutine."""
return self._coro.send(arg)

def throw(self, *args: Any, **kwargs: Any) -> Any:
"""Throw an exception into the wrapped coroutine."""
return self._coro.throw(*args, **kwargs)

def close(self) -> None:
"""Close the wrapped coroutine."""
return self._coro.close()

def __await__(self) -> Generator[Any, None, SnowflakeConnection]:
"""Enable await connect(...)"""
return self._coro.__await__()

def __iter__(self) -> Generator[Any, None, SnowflakeConnection]:
"""Make the wrapper iterable like a coroutine."""
return self.__await__()

# This approach requires idempotent __aenter__ of SnowflakeConnection class - so check if connected and do not repeat connecting
async def __aenter__(self) -> SnowflakeConnection:
"""Enable async with connect(...) as conn:"""
self._conn = await self._coro
return await self._conn.__aenter__()

async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
"""Exit async context manager."""
if self._conn is not None:
return await self._conn.__aexit__(exc_type, exc, tb)
else:
return None


@wraps(SnowflakeConnection.__init__)
def connect(**kwargs: Any) -> HybridCoroutineContextManager[SnowflakeConnection]:
"""Create and connect to a Snowflake connection asynchronously.

Returns an awaitable that can also be used as an async context manager.
Supports both patterns:
- conn = await connect(...)
- async with connect(...) as conn:
"""

async def _connect_coro() -> SnowflakeConnection:
conn = SnowflakeConnection(**kwargs)
await conn.connect()
return conn

async def connect(**kwargs) -> SnowflakeConnection:
conn = SnowflakeConnection(**kwargs)
await conn.connect()
return conn
return _AsyncConnectContextManager(_connect_coro())
21 changes: 20 additions & 1 deletion src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ def __init__(
connections_file_path: pathlib.Path | None = None,
**kwargs,
) -> None:
"""Create a new SnowflakeConnection.

Connections can be loaded from the TOML file located at
snowflake.connector.constants.CONNECTIONS_FILE.

When connection_name is supplied we will first load that connection
and then override any other values supplied.

When no arguments are given (other than connection_file_path) the
default connection will be loaded first. Note that no overwriting is
supported in this case.

If overwriting values from the default connection is desirable, supply
the name explicitly.
"""
# note we don't call super here because asyncio can not/is not recommended
# to perform async operation in the __init__ while in the sync connection we
# perform connect
Expand Down Expand Up @@ -173,7 +188,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):

async def __aenter__(self) -> SnowflakeConnection:
"""Context manager."""
await self.connect()
# Idempotent __aenter__ - required to be able to use both:
# - with snowflake.connector.aio.SnowflakeConnection(**k)
# - with snowflake.connector.aio.connect(**k)
if self.is_closed():
await self.connect()
return self

async def __aexit__(
Expand Down
26 changes: 15 additions & 11 deletions test/integ/aio_it/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
get_db_parameters,
is_public_testaccount,
)
from typing import AsyncContextManager, AsyncGenerator, Callable
from typing import Any, AsyncContextManager, AsyncGenerator, Callable

import pytest

from snowflake.connector.aio import SnowflakeConnection
from snowflake.connector.aio import connect as async_connect
from snowflake.connector.aio._telemetry import TelemetryClient
from snowflake.connector.connection import DefaultConverterClass
from snowflake.connector.telemetry import TelemetryData
Expand Down Expand Up @@ -70,13 +71,7 @@ def capture_sf_telemetry_async() -> TelemetryCaptureFixtureAsync:
return TelemetryCaptureFixtureAsync()


async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
"""Creates a connection using the parameters defined in parameters.py.

You can select from the different connections by supplying the appropiate
connection_name parameter and then anything else supplied will overwrite the values
from parameters.py.
"""
def fill_conn_kwargs_for_tests(connection_name: str, **kwargs) -> dict[str, Any]:
ret = get_db_parameters(connection_name)
ret.update(kwargs)

Expand All @@ -95,9 +90,18 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
ret.pop("private_key", None)
ret.pop("private_key_file", None)

connection = SnowflakeConnection(**ret)
await connection.connect()
return connection
return ret


async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
"""Creates a connection using the parameters defined in parameters.py.

You can select from the different connections by supplying the appropiate
connection_name parameter and then anything else supplied will overwrite the values
from parameters.py.
"""
ret = fill_conn_kwargs_for_tests(connection_name, **kwargs)
return await async_connect(**ret)


@asynccontextmanager
Expand Down
60 changes: 59 additions & 1 deletion test/integ/aio_it/test_connection_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
CONNECTION_PARAMETERS_ADMIN = {}
from snowflake.connector.aio.auth import AuthByOkta, AuthByPlugin

from .conftest import create_connection
from .conftest import create_connection, fill_conn_kwargs_for_tests

try:
from snowflake.connector.errorcode import ER_FAILED_PROCESSING_QMARK
Expand Down Expand Up @@ -1466,6 +1466,64 @@ async def test_platform_detection_timeout(conn_cnx):
assert cnx.platform_detection_timeout_seconds == 2.5


@pytest.mark.skipolddriver
async def test_conn_cnx_basic(conn_cnx):
"""Tests platform detection timeout.

Creates a connection with platform_detection_timeout parameter.
"""
async with conn_cnx() as conn:
async with conn.cursor() as cur:
result = await (await cur.execute("select 1")).fetchall()
assert len(result) == 1
assert result[0][0] == 1


@pytest.mark.skipolddriver
async def test_conn_assigned_method(conn_cnx):
conn = await snowflake.connector.aio.connect(
**fill_conn_kwargs_for_tests("default")
)
async with conn.cursor() as cur:
result = await (await cur.execute("select 1")).fetchall()
assert len(result) == 1
assert result[0][0] == 1


@pytest.mark.skipolddriver
async def test_conn_assigned_class(conn_cnx):
conn = snowflake.connector.aio.SnowflakeConnection(
**fill_conn_kwargs_for_tests("default")
)
await conn.connect()
async with conn.cursor() as cur:
result = await (await cur.execute("select 1")).fetchall()
assert len(result) == 1
assert result[0][0] == 1


@pytest.mark.skipolddriver
async def test_conn_with_method(conn_cnx):
async with snowflake.connector.aio.connect(
**fill_conn_kwargs_for_tests("default")
) as conn:
async with conn.cursor() as cur:
result = await (await cur.execute("select 1")).fetchall()
assert len(result) == 1
assert result[0][0] == 1


@pytest.mark.skipolddriver
async def test_conn_with_class(conn_cnx):
async with snowflake.connector.aio.SnowflakeConnection(
**fill_conn_kwargs_for_tests("default")
) as conn:
async with conn.cursor() as cur:
result = await (await cur.execute("select 1")).fetchall()
assert len(result) == 1
assert result[0][0] == 1


@pytest.mark.skipolddriver
async def test_platform_detection_zero_timeout(conn_cnx):
with (
Expand Down
Loading
Loading