|
88 | 88 | ) |
89 | 89 | from .converter import SnowflakeConverter |
90 | 90 | from .crl import CRLConfig |
| 91 | +from .crl_cache import CRLCacheFactory |
91 | 92 | from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase |
92 | 93 | from .description import ( |
93 | 94 | CLIENT_NAME, |
@@ -1152,11 +1153,17 @@ def connect(self, **kwargs) -> None: |
1152 | 1153 | else: |
1153 | 1154 | self.__open_connection() |
1154 | 1155 |
|
| 1156 | + # Register the connection in the pool after successful connection |
| 1157 | + _connections_registry.add_connection(self) |
| 1158 | + |
1155 | 1159 | def close(self, retry: bool = True) -> None: |
1156 | 1160 | """Closes the connection.""" |
1157 | 1161 | # unregister to dereference connection object as it's already closed after the execution |
1158 | 1162 | atexit.unregister(self._close_at_exit) |
1159 | 1163 | try: |
| 1164 | + # Remove connection from the pool |
| 1165 | + _connections_registry.remove_connection(self) |
| 1166 | + |
1160 | 1167 | if not self.rest: |
1161 | 1168 | logger.debug("Rest object has been destroyed, cannot close session") |
1162 | 1169 | return |
@@ -2533,3 +2540,56 @@ def _detect_application() -> None | str: |
2533 | 2540 | return "jupyter_notebook" |
2534 | 2541 | if "snowbooks" in sys.modules: |
2535 | 2542 | return "snowflake_notebook" |
| 2543 | + |
| 2544 | + |
| 2545 | +class _ConnectionsRegistry: |
| 2546 | + """Thread-safe registry for tracking opened SnowflakeConnection instances. |
| 2547 | +
|
| 2548 | + This class maintains a registry of active connections using weak references |
| 2549 | + to avoid preventing garbage collection. |
| 2550 | + """ |
| 2551 | + |
| 2552 | + def __init__(self): |
| 2553 | + """Initialize the connections registry with an empty registry and a lock.""" |
| 2554 | + self._connections: weakref.WeakSet = weakref.WeakSet() |
| 2555 | + self._lock = Lock() |
| 2556 | + |
| 2557 | + def add_connection(self, connection: SnowflakeConnection) -> None: |
| 2558 | + """Add a connection to the registry. |
| 2559 | +
|
| 2560 | + Args: |
| 2561 | + connection: The SnowflakeConnection instance to register. |
| 2562 | + """ |
| 2563 | + with self._lock: |
| 2564 | + self._connections.add(connection) |
| 2565 | + logger.debug( |
| 2566 | + f"Connection {id(connection)} added to pool. Total connections: {len(self._connections)}" |
| 2567 | + ) |
| 2568 | + |
| 2569 | + def remove_connection(self, connection: SnowflakeConnection) -> None: |
| 2570 | + """Remove a connection from the registry. |
| 2571 | +
|
| 2572 | + Args: |
| 2573 | + connection: The SnowflakeConnection instance to unregister. |
| 2574 | + """ |
| 2575 | + with self._lock: |
| 2576 | + self._connections.discard(connection) |
| 2577 | + logger.debug( |
| 2578 | + f"Connection {id(connection)} removed from registry. Total connections: {len(self._connections)}" |
| 2579 | + ) |
| 2580 | + |
| 2581 | + if len(self._connections) == 0: |
| 2582 | + self._last_connection_handler() |
| 2583 | + |
| 2584 | + def _last_connection_handler(self): |
| 2585 | + # If no connections left then stop CRL background task |
| 2586 | + # to avoid script dangling |
| 2587 | + CRLCacheFactory.stop_periodic_cleanup() |
| 2588 | + |
| 2589 | + def get_connection_count(self) -> int: |
| 2590 | + with self._lock: |
| 2591 | + return len(self._connections) |
| 2592 | + |
| 2593 | + |
| 2594 | +# Global instance of the connections pool |
| 2595 | +_connections_registry = _ConnectionsRegistry() |
0 commit comments