Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
)
from .converter import SnowflakeConverter
from .crl import CRLConfig
from .crl_cache import CRLCacheFactory
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
from .description import (
CLIENT_NAME,
Expand Down Expand Up @@ -1152,11 +1153,17 @@ def connect(self, **kwargs) -> None:
else:
self.__open_connection()

# Register the connection in the pool after successful connection
_connections_pool.add_connection(self)

def close(self, retry: bool = True) -> None:
"""Closes the connection."""
# unregister to dereference connection object as it's already closed after the execution
atexit.unregister(self._close_at_exit)
try:
# Remove connection from the pool
_connections_pool.remove_connection(self)

if not self.rest:
logger.debug("Rest object has been destroyed, cannot close session")
return
Expand Down Expand Up @@ -2533,3 +2540,53 @@ def _detect_application() -> None | str:
return "jupyter_notebook"
if "snowbooks" in sys.modules:
return "snowflake_notebook"


class _ConnectionsPool:
"""Thread-safe pool for tracking opened SnowflakeConnection instances.

This class maintains a registry of active connections using weak references
to avoid preventing garbage collection.
"""

def __init__(self):
"""Initialize the connections pool with an empty registry and a lock."""
self._connections: weakref.WeakSet = weakref.WeakSet()
self._lock = Lock()

def add_connection(self, connection: SnowflakeConnection) -> None:
"""Add a connection to the pool.

Args:
connection: The SnowflakeConnection instance to register.
"""
with self._lock:
self._connections.add(connection)
logger.debug(
f"Connection {id(connection)} added to pool. Total connections: {len(self._connections)}"
)

def remove_connection(self, connection: SnowflakeConnection) -> None:
"""Remove a connection from the pool.

Args:
connection: The SnowflakeConnection instance to unregister.
"""
with self._lock:
self._connections.discard(connection)
logger.debug(
f"Connection {id(connection)} removed from pool. Total connections: {len(self._connections)}"
)

if len(self._connections) == 0:
# If no connections left then stop CRL background task
# to avoid script dangling
CRLCacheFactory.stop_periodic_cleanup()

def get_connection_count(self) -> int:
with self._lock:
return len(self._connections)


# Global instance of the connections pool
_connections_pool = _ConnectionsPool()
2 changes: 1 addition & 1 deletion src/snowflake/connector/crl.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class CRLConfig:
crl_cache_dir: Path | str | None = None
crl_cache_removal_delay_days: int = 7
crl_cache_cleanup_interval_hours: int = 1
crl_cache_start_cleanup: bool = False
crl_cache_start_cleanup: bool = True

@classmethod
def from_connection(cls, sf_connection) -> CRLConfig:
Expand Down
49 changes: 49 additions & 0 deletions test/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,3 +953,52 @@ def test_connect_metadata_preservation():
len(params) > 0
), "connect should have parameters from SnowflakeConnection.__init__"
# Should have parameters like account, user, password, etc.


def test_connections_pool(mock_post_requests):
"""Test that connections are properly tracked in the _ConnectionsPool."""
from snowflake.connector.connection import _connections_pool

# Get initial connection count
initial_count = _connections_pool.get_connection_count()

# Create and connect first connection
conn1 = fake_connector()
assert (
_connections_pool.get_connection_count() == initial_count + 1
), "Connection count should increase by 1 after creating a connection"

# Create and connect second connection
conn2 = fake_connector()
assert (
_connections_pool.get_connection_count() == initial_count + 2
), "Connection count should increase by 2 after creating two connections"

# Close first connection
conn1.close()
assert (
_connections_pool.get_connection_count() == initial_count + 1
), "Connection count should decrease by 1 after closing a connection"

# Close second connection
conn2.close()
assert (
_connections_pool.get_connection_count() == initial_count
), "Connection count should return to initial count after closing all connections"


@mock.patch("snowflake.connector.connection.CRLCacheFactory")
def test_connections_pool_stops_crl_task_if_empty(crl_mock, mock_post_requests):
"""Test the individual methods of _ConnectionsPool."""

# Create a connection
conn1 = fake_connector()
conn2 = fake_connector()

# Don't stop the task if pool is not empty
conn1.close()
crl_mock.stop_periodic_cleanup.assert_not_called()

# Stop the task if the pool is emptied
conn2.close()
crl_mock.stop_periodic_cleanup.assert_called_once()
Loading