Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
)
from .converter import SnowflakeConverter
from .crl import CRLConfig
from .crl_cache import CRLCacheFactory
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
from .description import (
CLIENT_NAME,
Expand Down Expand Up @@ -1152,11 +1153,17 @@ def connect(self, **kwargs) -> None:
else:
self.__open_connection()

# Register the connection in the pool after successful connection
_connections_registry.add_connection(self)

def close(self, retry: bool = True) -> None:
"""Closes the connection."""
# unregister to dereference connection object as it's already closed after the execution
atexit.unregister(self._close_at_exit)
try:
# Remove connection from the pool
_connections_registry.remove_connection(self)

if not self.rest:
logger.debug("Rest object has been destroyed, cannot close session")
return
Expand Down Expand Up @@ -2533,3 +2540,56 @@ def _detect_application() -> None | str:
return "jupyter_notebook"
if "snowbooks" in sys.modules:
return "snowflake_notebook"


class _ConnectionsRegistry:
"""Thread-safe registry for tracking opened SnowflakeConnection instances.

This class maintains a registry of active connections using weak references
to avoid preventing garbage collection.
"""

def __init__(self):
"""Initialize the connections registry with an empty registry and a lock."""
self._connections: weakref.WeakSet = weakref.WeakSet()
self._lock = Lock()

def add_connection(self, connection: SnowflakeConnection) -> None:
"""Add a connection to the registry.

Args:
connection: The SnowflakeConnection instance to register.
"""
with self._lock:
self._connections.add(connection)
logger.debug(
f"Connection {id(connection)} added to pool. Total connections: {len(self._connections)}"
)

def remove_connection(self, connection: SnowflakeConnection) -> None:
"""Remove a connection from the registry.

Args:
connection: The SnowflakeConnection instance to unregister.
"""
with self._lock:
self._connections.discard(connection)
logger.debug(
f"Connection {id(connection)} removed from registry. Total connections: {len(self._connections)}"
)

if len(self._connections) == 0:
self._last_connection_handler()

def _last_connection_handler(self):
# If no connections left then stop CRL background task
# to avoid script dangling
CRLCacheFactory.stop_periodic_cleanup()

def get_connection_count(self) -> int:
with self._lock:
return len(self._connections)


# Global instance of the connections pool
_connections_registry = _ConnectionsRegistry()
2 changes: 1 addition & 1 deletion src/snowflake/connector/crl.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class CRLConfig:
crl_cache_dir: Path | str | None = None
crl_cache_removal_delay_days: int = 7
crl_cache_cleanup_interval_hours: int = 1
crl_cache_start_cleanup: bool = False
crl_cache_start_cleanup: bool = True

@classmethod
def from_connection(cls, sf_connection) -> CRLConfig:
Expand Down
25 changes: 25 additions & 0 deletions test/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,3 +953,28 @@ def test_connect_metadata_preservation():
len(params) > 0
), "connect should have parameters from SnowflakeConnection.__init__"
# Should have parameters like account, user, password, etc.


@mock.patch("snowflake.connector.connection.CRLCacheFactory")
def test_connections_registry_lifecycle(crl_mock, mock_post_requests):
"""Test the individual methods of _ConnectionsPool."""
from snowflake.connector.connection import _ConnectionsRegistry

# Mock the registry to avoid side effects from other tests due to _ConnectionsRegistry being a singleton
with mock.patch(
"snowflake.connector.connection._connections_registry", _ConnectionsRegistry()
) as mock_registry:
# Create a connection
conn1 = fake_connector()
conn2 = fake_connector()
assert mock_registry.get_connection_count() == 2

# Don't stop the task if pool is not empty
conn1.close()
crl_mock.stop_periodic_cleanup.assert_not_called()
assert mock_registry.get_connection_count() == 1

# Stop the task if the pool is emptied
conn2.close()
assert mock_registry.get_connection_count() == 0
crl_mock.stop_periodic_cleanup.assert_called_once()
Loading