11from typing import Type , List , Optional , Union , Dict
22from itertools import zip_longest
33import dsnparse
4+ from contextlib import suppress
45
56from runtype import dataclass
67
8+ from ..utils import WeakCache
79from .base import Database , ThreadedDatabase
810from .postgresql import PostgreSQL
911from .mysql import MySQL
1921from .duckdb import DuckDB
2022
2123
24+
2225@dataclass
2326class MatchUriPath :
2427 database_cls : Type [Database ]
2528 params : List [str ]
2629 kwparams : List [str ] = []
27- help_str : str
30+ help_str : str = "<unspecified>"
2831
2932 def __post_init__ (self ):
3033 assert self .params == self .database_cls .CONNECT_URI_PARAMS , self .params
@@ -101,6 +104,7 @@ def __init__(self, database_by_scheme: Dict[str, Database]):
101104 name : MatchUriPath (cls , cls .CONNECT_URI_PARAMS , cls .CONNECT_URI_KWPARAMS , help_str = cls .CONNECT_URI_HELP )
102105 for name , cls in database_by_scheme .items ()
103106 }
107+ self .conn_cache = WeakCache ()
104108
105109 def connect_to_uri (self , db_uri : str , thread_count : Optional [int ] = 1 ) -> Database :
106110 """Connect to the given database uri
@@ -200,7 +204,7 @@ def _connection_created(self, db):
200204 "Nop function to be overridden by subclasses."
201205 return db
202206
203- def __call__ (self , db_conf : Union [str , dict ], thread_count : Optional [int ] = 1 ) -> Database :
207+ def __call__ (self , db_conf : Union [str , dict ], thread_count : Optional [int ] = 1 , shared : bool = True ) -> Database :
204208 """Connect to a database using the given database configuration.
205209
206210 Configuration can be given either as a URI string, or as a dict of {option: value}.
@@ -213,6 +217,7 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -
213217 Parameters:
214218 db_conf (str | dict): The configuration for the database to connect. URI or dict.
215219 thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
220+ shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True)
216221
217222 Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
218223
@@ -235,8 +240,19 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -
235240 >>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
236241 <data_diff.databases.mysql.MySQL object at 0x0000025DB3F94820>
237242 """
243+ if shared :
244+ with suppress (KeyError ):
245+ conn = self .conn_cache .get (db_conf )
246+ if not conn .is_closed :
247+ return conn
248+
238249 if isinstance (db_conf , str ):
239- return self .connect_to_uri (db_conf , thread_count )
250+ conn = self .connect_to_uri (db_conf , thread_count )
240251 elif isinstance (db_conf , dict ):
241- return self .connect_with_dict (db_conf , thread_count )
242- raise TypeError (f"db configuration must be a URI string or a dictionary. Instead got '{ db_conf } '." )
252+ conn = self .connect_with_dict (db_conf , thread_count )
253+ else :
254+ raise TypeError (f"db configuration must be a URI string or a dictionary. Instead got '{ db_conf } '." )
255+
256+ if shared :
257+ self .conn_cache .add (db_conf , conn )
258+ return conn
0 commit comments