diff --git a/src/snowflake/connector/aio/__init__.py b/src/snowflake/connector/aio/__init__.py index 0b0410eba..15a0fe469 100644 --- a/src/snowflake/connector/aio/__init__.py +++ b/src/snowflake/connector/aio/__init__.py @@ -1,5 +1,8 @@ from __future__ import annotations +from functools import wraps +from typing import Any, Coroutine, Generator, Protocol, TypeVar, runtime_checkable + from ._connection import SnowflakeConnection from ._cursor import DictCursor, SnowflakeCursor @@ -9,8 +12,152 @@ DictCursor, ] +# ============================================================================ +# DESIGN NOTES: +# +# Pattern similar to aiohttp.ClientSession.request() which similarly returns +# an object that can be both awaited and used as an async context manager. +# +# The async connect function uses a wrapper to support both: +# 1. Direct awaiting: conn = await connect(...) +# 2. Async context manager: async with connect(...) as conn: +# +# connect: A function decorated with @wraps(SnowflakeConnection.__init__) that +# preserves metadata for IDE support, type checking, and introspection. +# Returns a _AsyncConnectContextManager instance when called. +# +# _AsyncConnectContextManager: Implements __await__ and __aenter__/__aexit__ +# to support both patterns on the same awaitable. +# +# The @wraps decorator ensures that connect() has the same signature and +# documentation as SnowflakeConnection.__init__, making it behave identically +# to the sync snowflake.connector.connect function from an introspection POV. +# +# Metadata preservation is critical for IDE autocomplete, static type checkers, +# and documentation generation to work correctly on the async connect function. +# ============================================================================ + + +T = TypeVar("T") + + +@runtime_checkable +class HybridCoroutineContextManager(Protocol[T]): + """Protocol for a hybrid coroutine that is also an async context manager. + + Combines the full coroutine protocol (PEP 492) with async context manager + protocol (PEP 343/492), allowing code that expects either interface to work + seamlessly with instances of this protocol. + + This is used when external code needs to manage the coroutine lifecycle + (e.g., timeout handlers, async schedulers) or use it as a context manager. + """ + + # Full Coroutine Protocol (PEP 492) + def send(self, __arg: Any) -> Any: + """Send a value into the coroutine.""" + ... + + def throw( + self, + __typ: type[BaseException], + __val: BaseException | None = None, + __tb: Any = None, + ) -> Any: + """Throw an exception into the coroutine.""" + ... + + def close(self) -> None: + """Close the coroutine.""" + ... + + def __await__(self) -> Generator[Any, None, T]: + """Return awaitable generator.""" + ... + + def __iter__(self) -> Generator[Any, None, T]: + """Iterate over the coroutine.""" + ... + + # Async Context Manager Protocol (PEP 343) + async def __aenter__(self) -> T: + """Async context manager entry.""" + ... + + async def __aexit__( + self, + __exc_type: type[BaseException] | None, + __exc_val: BaseException | None, + __exc_tb: Any, + ) -> bool | None: + """Async context manager exit.""" + ... + + +class _AsyncConnectContextManager(HybridCoroutineContextManager[SnowflakeConnection]): + """Hybrid wrapper that enables both awaiting and async context manager usage. + + Allows both patterns: + - conn = await connect(...) + - async with connect(...) as conn: + + Implements the full coroutine protocol for maximum compatibility. + Satisfies the HybridCoroutineContextManager protocol. + """ + + __slots__ = ("_coro", "_conn") + + def __init__(self, coro: Coroutine[Any, Any, SnowflakeConnection]) -> None: + self._coro = coro + self._conn: SnowflakeConnection | None = None + + def send(self, arg: Any) -> Any: + """Send a value into the wrapped coroutine.""" + return self._coro.send(arg) + + def throw(self, *args: Any, **kwargs: Any) -> Any: + """Throw an exception into the wrapped coroutine.""" + return self._coro.throw(*args, **kwargs) + + def close(self) -> None: + """Close the wrapped coroutine.""" + return self._coro.close() + + def __await__(self) -> Generator[Any, None, SnowflakeConnection]: + """Enable await connect(...)""" + return self._coro.__await__() + + def __iter__(self) -> Generator[Any, None, SnowflakeConnection]: + """Make the wrapper iterable like a coroutine.""" + return self.__await__() + + # This approach requires idempotent __aenter__ of SnowflakeConnection class - so check if connected and do not repeat connecting + async def __aenter__(self) -> SnowflakeConnection: + """Enable async with connect(...) as conn:""" + self._conn = await self._coro + return await self._conn.__aenter__() + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + """Exit async context manager.""" + if self._conn is not None: + return await self._conn.__aexit__(exc_type, exc, tb) + else: + return None + + +@wraps(SnowflakeConnection.__init__) +def connect(**kwargs: Any) -> HybridCoroutineContextManager[SnowflakeConnection]: + """Create and connect to a Snowflake connection asynchronously. + + Returns an awaitable that can also be used as an async context manager. + Supports both patterns: + - conn = await connect(...) + - async with connect(...) as conn: + """ + + async def _connect_coro() -> SnowflakeConnection: + conn = SnowflakeConnection(**kwargs) + await conn.connect() + return conn -async def connect(**kwargs) -> SnowflakeConnection: - conn = SnowflakeConnection(**kwargs) - await conn.connect() - return conn + return _AsyncConnectContextManager(_connect_coro()) diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index 9e0f6a926..3b85f4df4 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -123,6 +123,21 @@ def __init__( connections_file_path: pathlib.Path | None = None, **kwargs, ) -> None: + """Create a new SnowflakeConnection. + + Connections can be loaded from the TOML file located at + snowflake.connector.constants.CONNECTIONS_FILE. + + When connection_name is supplied we will first load that connection + and then override any other values supplied. + + When no arguments are given (other than connection_file_path) the + default connection will be loaded first. Note that no overwriting is + supported in this case. + + If overwriting values from the default connection is desirable, supply + the name explicitly. + """ # note we don't call super here because asyncio can not/is not recommended # to perform async operation in the __init__ while in the sync connection we # perform connect @@ -173,7 +188,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): async def __aenter__(self) -> SnowflakeConnection: """Context manager.""" - await self.connect() + # Idempotent __aenter__ - required to be able to use both: + # - with snowflake.connector.aio.SnowflakeConnection(**k) + # - with snowflake.connector.aio.connect(**k) + if self.is_closed(): + await self.connect() return self async def __aexit__( diff --git a/test/integ/aio_it/conftest.py b/test/integ/aio_it/conftest.py index c3949c242..86df4efc8 100644 --- a/test/integ/aio_it/conftest.py +++ b/test/integ/aio_it/conftest.py @@ -9,11 +9,12 @@ get_db_parameters, is_public_testaccount, ) -from typing import AsyncContextManager, AsyncGenerator, Callable +from typing import Any, AsyncContextManager, AsyncGenerator, Callable import pytest from snowflake.connector.aio import SnowflakeConnection +from snowflake.connector.aio import connect as async_connect from snowflake.connector.aio._telemetry import TelemetryClient from snowflake.connector.connection import DefaultConverterClass from snowflake.connector.telemetry import TelemetryData @@ -70,13 +71,7 @@ def capture_sf_telemetry_async() -> TelemetryCaptureFixtureAsync: return TelemetryCaptureFixtureAsync() -async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection: - """Creates a connection using the parameters defined in parameters.py. - - You can select from the different connections by supplying the appropiate - connection_name parameter and then anything else supplied will overwrite the values - from parameters.py. - """ +def fill_conn_kwargs_for_tests(connection_name: str, **kwargs) -> dict[str, Any]: ret = get_db_parameters(connection_name) ret.update(kwargs) @@ -95,9 +90,18 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti ret.pop("private_key", None) ret.pop("private_key_file", None) - connection = SnowflakeConnection(**ret) - await connection.connect() - return connection + return ret + + +async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection: + """Creates a connection using the parameters defined in parameters.py. + + You can select from the different connections by supplying the appropiate + connection_name parameter and then anything else supplied will overwrite the values + from parameters.py. + """ + ret = fill_conn_kwargs_for_tests(connection_name, **kwargs) + return await async_connect(**ret) @asynccontextmanager diff --git a/test/integ/aio_it/test_connection_async.py b/test/integ/aio_it/test_connection_async.py index 357173bbf..de2497820 100644 --- a/test/integ/aio_it/test_connection_async.py +++ b/test/integ/aio_it/test_connection_async.py @@ -46,7 +46,7 @@ CONNECTION_PARAMETERS_ADMIN = {} from snowflake.connector.aio.auth import AuthByOkta, AuthByPlugin -from .conftest import create_connection +from .conftest import create_connection, fill_conn_kwargs_for_tests try: from snowflake.connector.errorcode import ER_FAILED_PROCESSING_QMARK @@ -1466,6 +1466,64 @@ async def test_platform_detection_timeout(conn_cnx): assert cnx.platform_detection_timeout_seconds == 2.5 +@pytest.mark.skipolddriver +async def test_conn_cnx_basic(conn_cnx): + """Tests platform detection timeout. + + Creates a connection with platform_detection_timeout parameter. + """ + async with conn_cnx() as conn: + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + +@pytest.mark.skipolddriver +async def test_conn_assigned_method(conn_cnx): + conn = await snowflake.connector.aio.connect( + **fill_conn_kwargs_for_tests("default") + ) + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + +@pytest.mark.skipolddriver +async def test_conn_assigned_class(conn_cnx): + conn = snowflake.connector.aio.SnowflakeConnection( + **fill_conn_kwargs_for_tests("default") + ) + await conn.connect() + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + +@pytest.mark.skipolddriver +async def test_conn_with_method(conn_cnx): + async with snowflake.connector.aio.connect( + **fill_conn_kwargs_for_tests("default") + ) as conn: + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + +@pytest.mark.skipolddriver +async def test_conn_with_class(conn_cnx): + async with snowflake.connector.aio.SnowflakeConnection( + **fill_conn_kwargs_for_tests("default") + ) as conn: + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + @pytest.mark.skipolddriver async def test_platform_detection_zero_timeout(conn_cnx): with ( diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py index 1253bef24..eedf562ae 100644 --- a/test/unit/aio/test_connection_async_unit.py +++ b/test/unit/aio/test_connection_async_unit.py @@ -932,3 +932,93 @@ async def test_large_query_through_proxy_async( "/queries/v1/query-request" in r["request"]["url"] for r in target_reqs["requests"] ) + + +@pytest.mark.skipolddriver +async def test_connect_metadata_preservation(): + """Test that the async connect function preserves metadata from SnowflakeConnection.__init__. + + This test verifies that various inspection methods return consistent metadata, + ensuring IDE support, type checking, and documentation generation work correctly. + """ + import inspect + + from snowflake.connector.aio import SnowflakeConnection, connect + + # Test 1: Check __name__ is correct + assert ( + connect.__name__ == "__init__" + ), f"connect.__name__ should be '__init__', but got '{connect.__name__}'" + assert ( + connect.__qualname__ == "SnowflakeConnection.__init__" + ), f"connect.__qualname__ should be 'connect', but got '{connect.__qualname__}'" + + # Test 2: Check __wrapped__ points to SnowflakeConnection.__init__ + assert hasattr(connect, "__wrapped__"), "connect should have __wrapped__ attribute" + assert ( + connect.__wrapped__ is SnowflakeConnection.__init__ + ), "connect.__wrapped__ should reference SnowflakeConnection.__init__" + + # Test 3: Check __module__ is preserved + assert hasattr(connect, "__module__"), "connect should have __module__ attribute" + assert connect.__module__ == SnowflakeConnection.__init__.__module__, ( + f"connect.__module__ should match SnowflakeConnection.__init__.__module__, " + f"but got '{connect.__module__}' vs '{SnowflakeConnection.__init__.__module__}'" + ) + + # Test 4: Check __doc__ is preserved + assert hasattr(connect, "__doc__"), "connect should have __doc__ attribute" + assert ( + connect.__doc__ == SnowflakeConnection.__init__.__doc__ + ), "connect.__doc__ should match SnowflakeConnection.__init__.__doc__" + + # Test 5: Check __annotations__ are preserved (or at least available) + assert hasattr( + connect, "__annotations__" + ), "connect should have __annotations__ attribute" + src_annotations = getattr(SnowflakeConnection.__init__, "__annotations__", {}) + connect_annotations = getattr(connect, "__annotations__", {}) + assert connect_annotations == src_annotations, ( + f"connect.__annotations__ should match SnowflakeConnection.__init__.__annotations__, " + f"but got {connect_annotations} vs {src_annotations}" + ) + + # Test 6: Check inspect.signature works correctly + try: + connect_sig = inspect.signature(connect) + source_sig = inspect.signature(SnowflakeConnection.__init__) + assert str(connect_sig) == str(source_sig), ( + f"inspect.signature(connect) should match inspect.signature(SnowflakeConnection.__init__), " + f"but got '{connect_sig}' vs '{source_sig}'" + ) + except Exception as e: + pytest.fail(f"inspect.signature(connect) failed: {e}") + + # Test 7: Check inspect.getdoc works correctly + connect_doc = inspect.getdoc(connect) + source_doc = inspect.getdoc(SnowflakeConnection.__init__) + assert ( + connect_doc == source_doc + ), "inspect.getdoc(connect) should match inspect.getdoc(SnowflakeConnection.__init__)" + + # Test 8: Check that connect is callable and returns expected type + assert callable(connect), "connect should be callable" + + # Test 9: Check type() and __class__ values (important for user introspection) + assert ( + type(connect).__name__ == "function" + ), f"type(connect).__name__ should be 'function', but got '{type(connect).__name__}'" + assert ( + connect.__class__.__name__ == "function" + ), f"connect.__class__.__name__ should be 'function', but got '{connect.__class__.__name__}'" + assert inspect.isfunction( + connect + ), "connect should be recognized as a function by inspect.isfunction()" + + # Test 10: Verify the function has proper introspection capabilities + # IDEs and type checkers should be able to resolve parameters + sig = inspect.signature(connect) + params = list(sig.parameters.keys()) + assert ( + len(params) > 0 + ), "connect should have parameters from SnowflakeConnection.__init__"