@@ -39,13 +39,14 @@ class Connection(metaclass=ConnectionMeta):
3939
4040 __slots__ = ('_protocol' , '_transport' , '_loop' , '_types_stmt' ,
4141 '_type_by_name_stmt' , '_top_xact' , '_uid' , '_aborted' ,
42- '_stmt_cache_max_size' , ' _stmt_cache' , '_stmts_to_close' ,
42+ '_stmt_cache' , '_stmts_to_close' ,
4343 '_addr' , '_opts' , '_command_timeout' , '_listeners' ,
4444 '_server_version' , '_server_caps' , '_intro_query' ,
4545 '_reset_query' , '_proxy' , '_stmt_exclusive_section' )
4646
4747 def __init__ (self , protocol , transport , loop , addr , opts , * ,
48- statement_cache_size , command_timeout ):
48+ statement_cache_size , command_timeout ,
49+ max_cached_statement_lifetime ):
4950 self ._protocol = protocol
5051 self ._transport = transport
5152 self ._loop = loop
@@ -58,8 +59,12 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
5859 self ._addr = addr
5960 self ._opts = opts
6061
61- self ._stmt_cache_max_size = statement_cache_size
62- self ._stmt_cache = collections .OrderedDict ()
62+ self ._stmt_cache = _StatementCache (
63+ loop = loop ,
64+ max_size = statement_cache_size ,
65+ on_remove = self ._maybe_gc_stmt ,
66+ max_lifetime = max_cached_statement_lifetime )
67+
6368 self ._stmts_to_close = set ()
6469
6570 if command_timeout is not None :
@@ -126,6 +131,8 @@ async def add_listener(self, channel, callback):
126131
127132 async def remove_listener (self , channel , callback ):
128133 """Remove a listening callback on the specified channel."""
134+ if self .is_closed ():
135+ return
129136 if channel not in self ._listeners :
130137 return
131138 if callback not in self ._listeners [channel ]:
@@ -266,46 +273,33 @@ async def executemany(self, command: str, args,
266273 return await self ._executemany (command , args , timeout )
267274
268275 async def _get_statement (self , query , timeout , * , named : bool = False ):
269- use_cache = self ._stmt_cache_max_size > 0
270- if use_cache :
271- try :
272- state = self ._stmt_cache [query ]
273- except KeyError :
274- pass
275- else :
276- self ._stmt_cache .move_to_end (query , last = True )
277- if not state .closed :
278- return state
279-
280- protocol = self ._protocol
276+ statement = self ._stmt_cache .get (query )
277+ if statement is not None :
278+ return statement
281279
282- if use_cache or named :
280+ if self . _stmt_cache . get_max_size () or named :
283281 stmt_name = self ._get_unique_id ('stmt' )
284282 else :
285283 stmt_name = ''
286284
287- state = await protocol .prepare (stmt_name , query , timeout )
285+ statement = await self . _protocol .prepare (stmt_name , query , timeout )
288286
289- ready = state ._init_types ()
287+ ready = statement ._init_types ()
290288 if ready is not True :
291289 if self ._types_stmt is None :
292290 self ._types_stmt = await self .prepare (self ._intro_query )
293291
294292 types = await self ._types_stmt .fetch (list (ready ))
295- protocol .get_settings ().register_data_types (types )
293+ self . _protocol .get_settings ().register_data_types (types )
296294
297- if use_cache :
298- if len (self ._stmt_cache ) > self ._stmt_cache_max_size - 1 :
299- old_query , old_state = self ._stmt_cache .popitem (last = False )
300- self ._maybe_gc_stmt (old_state )
301- self ._stmt_cache [query ] = state
295+ self ._stmt_cache .put (query , statement )
302296
303297 # If we've just created a new statement object, check if there
304298 # are any statements for GC.
305299 if self ._stmts_to_close :
306300 await self ._cleanup_stmts ()
307301
308- return state
302+ return statement
309303
310304 def cursor (self , query , * args , prefetch = None , timeout = None ):
311305 """Return a *cursor factory* for the specified query.
@@ -457,14 +451,14 @@ async def close(self):
457451 """Close the connection gracefully."""
458452 if self .is_closed ():
459453 return
460- self ._close_stmts ()
454+ self ._mark_stmts_as_closed ()
461455 self ._listeners = {}
462456 self ._aborted = True
463457 await self ._protocol .close ()
464458
465459 def terminate (self ):
466460 """Terminate the connection without waiting for pending data."""
467- self ._close_stmts ()
461+ self ._mark_stmts_as_closed ()
468462 self ._listeners = {}
469463 self ._aborted = True
470464 self ._protocol .abort ()
@@ -484,8 +478,8 @@ def _get_unique_id(self, prefix):
484478 self ._uid += 1
485479 return '__asyncpg_{}_{}__' .format (prefix , self ._uid )
486480
487- def _close_stmts (self ):
488- for stmt in self ._stmt_cache .values ():
481+ def _mark_stmts_as_closed (self ):
482+ for stmt in self ._stmt_cache .iter_statements ():
489483 stmt .mark_closed ()
490484
491485 for stmt in self ._stmts_to_close :
@@ -495,11 +489,22 @@ def _close_stmts(self):
495489 self ._stmts_to_close .clear ()
496490
497491 def _maybe_gc_stmt (self , stmt ):
498- if stmt .refs == 0 and stmt .query not in self ._stmt_cache :
492+ if stmt .refs == 0 and not self ._stmt_cache .has (stmt .query ):
493+ # If low-level `stmt` isn't referenced from any high-level
494+ # `PreparedStatement` object and is not in the `_stmt_cache`:
495+ #
496+ # * mark it as closed, which will make it non-usable
497+ # for any `PreparedStatement` or for methods like
498+ # `Connection.fetch()`.
499+ #
500+ # * schedule it to be formally closed on the server.
499501 stmt .mark_closed ()
500502 self ._stmts_to_close .add (stmt )
501503
502504 async def _cleanup_stmts (self ):
505+ # Called whenever we create a new prepared statement in
506+ # `Connection._get_statement()` and `_stmts_to_close` is
507+ # not empty.
503508 to_close = self ._stmts_to_close
504509 self ._stmts_to_close = set ()
505510 for stmt in to_close :
@@ -700,6 +705,7 @@ async def connect(dsn=None, *,
700705 loop = None ,
701706 timeout = 60 ,
702707 statement_cache_size = 100 ,
708+ max_cached_statement_lifetime = 300 ,
703709 command_timeout = None ,
704710 __connection_class__ = Connection ,
705711 ** opts ):
@@ -735,6 +741,12 @@ async def connect(dsn=None, *,
735741 :param float timeout: connection timeout in seconds.
736742
737743 :param int statement_cache_size: the size of prepared statement LRU cache.
744+ Pass ``0`` to disable the cache.
745+
746+ :param int max_cached_statement_lifetime:
747+ the maximum time in seconds a prepared statement will stay
748+ in the cache. Pass ``0`` to allow statements to be cached
749+ indefinitely.
738750
739751 :param float command_timeout: the default timeout for operations on
740752 this connection (the default is no timeout).
@@ -753,6 +765,9 @@ async def connect(dsn=None, *,
753765 ... print(types)
754766 >>> asyncio.get_event_loop().run_until_complete(run())
755767 [<Record typname='bool' typnamespace=11 ...
768+
769+ .. versionchanged:: 0.10.0
770+ Added ``max_cached_statement_lifetime`` parameter.
756771 """
757772 if loop is None :
758773 loop = asyncio .get_event_loop ()
@@ -796,13 +811,162 @@ async def connect(dsn=None, *,
796811 tr .close ()
797812 raise
798813
799- con = __connection_class__ (pr , tr , loop , addr , opts ,
800- statement_cache_size = statement_cache_size ,
801- command_timeout = command_timeout )
814+ con = __connection_class__ (
815+ pr , tr , loop , addr , opts ,
816+ statement_cache_size = statement_cache_size ,
817+ max_cached_statement_lifetime = max_cached_statement_lifetime ,
818+ command_timeout = command_timeout )
819+
802820 pr .set_connection (con )
803821 return con
804822
805823
824+ class _StatementCacheEntry :
825+
826+ __slots__ = ('_query' , '_statement' , '_cache' , '_cleanup_cb' )
827+
828+ def __init__ (self , cache , query , statement ):
829+ self ._cache = cache
830+ self ._query = query
831+ self ._statement = statement
832+ self ._cleanup_cb = None
833+
834+
835+ class _StatementCache :
836+
837+ __slots__ = ('_loop' , '_entries' , '_max_size' , '_on_remove' ,
838+ '_max_lifetime' )
839+
840+ def __init__ (self , * , loop , max_size , on_remove , max_lifetime ):
841+ self ._loop = loop
842+ self ._max_size = max_size
843+ self ._on_remove = on_remove
844+ self ._max_lifetime = max_lifetime
845+
846+ # We use an OrderedDict for LRU implementation. Operations:
847+ #
848+ # * We use a simple `__setitem__` to push a new entry:
849+ # `entries[key] = new_entry`
850+ # That will push `new_entry` to the *end* of the entries dict.
851+ #
852+ # * When we have a cache hit, we call
853+ # `entries.move_to_end(key, last=True)`
854+ # to move the entry to the *end* of the entries dict.
855+ #
856+ # * When we need to remove entries to maintain `max_size`, we call
857+ # `entries.popitem(last=False)`
858+ # to remove an entry from the *beginning* of the entries dict.
859+ #
860+ # So new entries and hits are always promoted to the end of the
861+ # entries dict, whereas the unused one will group in the
862+ # beginning of it.
863+ self ._entries = collections .OrderedDict ()
864+
865+ def __len__ (self ):
866+ return len (self ._entries )
867+
868+ def get_max_size (self ):
869+ return self ._max_size
870+
871+ def set_max_size (self , new_size ):
872+ assert new_size >= 0
873+ self ._max_size = new_size
874+ self ._maybe_cleanup ()
875+
876+ def get_max_lifetime (self ):
877+ return self ._max_lifetime
878+
879+ def set_max_lifetime (self , new_lifetime ):
880+ assert new_lifetime >= 0
881+ self ._max_lifetime = new_lifetime
882+ for entry in self ._entries .values ():
883+ # For every entry cancel the existing callback
884+ # and setup a new one if necessary.
885+ self ._set_entry_timeout (entry )
886+
887+ def get (self , query , * , promote = True ):
888+ if not self ._max_size :
889+ # The cache is disabled.
890+ return
891+
892+ entry = self ._entries .get (query ) # type: _StatementCacheEntry
893+ if entry is None :
894+ return
895+
896+ if entry ._statement .closed :
897+ # Happens in unittests when we call `stmt._state.mark_closed()`
898+ # manually.
899+ self ._entries .pop (query )
900+ self ._clear_entry_callback (entry )
901+ return
902+
903+ if promote :
904+ # `promote` is `False` when `get()` is called by `has()`.
905+ self ._entries .move_to_end (query , last = True )
906+
907+ return entry ._statement
908+
909+ def has (self , query ):
910+ return self .get (query , promote = False ) is not None
911+
912+ def put (self , query , statement ):
913+ if not self ._max_size :
914+ # The cache is disabled.
915+ return
916+
917+ self ._entries [query ] = self ._new_entry (query , statement )
918+
919+ # Check if the cache is bigger than max_size and trim it
920+ # if necessary.
921+ self ._maybe_cleanup ()
922+
923+ def iter_statements (self ):
924+ return (e ._statement for e in self ._entries .values ())
925+
926+ def clear (self ):
927+ # First, make sure that we cancel all scheduled callbacks.
928+ for entry in self ._entries .values ():
929+ self ._clear_entry_callback (entry )
930+
931+ # Clear the entries dict.
932+ self ._entries .clear ()
933+
934+ def _set_entry_timeout (self , entry ):
935+ # Clear the existing timeout.
936+ self ._clear_entry_callback (entry )
937+
938+ # Set the new timeout if it's not 0.
939+ if self ._max_lifetime :
940+ entry ._cleanup_cb = self ._loop .call_later (
941+ self ._max_lifetime , self ._on_entry_expired , entry )
942+
943+ def _new_entry (self , query , statement ):
944+ entry = _StatementCacheEntry (self , query , statement )
945+ self ._set_entry_timeout (entry )
946+ return entry
947+
948+ def _on_entry_expired (self , entry ):
949+ # `call_later` callback, called when an entry stayed longer
950+ # than `self._max_lifetime`.
951+ if self ._entries .get (entry ._query ) is entry :
952+ self ._entries .pop (entry ._query )
953+ self ._on_remove (entry ._statement )
954+
955+ def _clear_entry_callback (self , entry ):
956+ if entry ._cleanup_cb is not None :
957+ entry ._cleanup_cb .cancel ()
958+
959+ def _maybe_cleanup (self ):
960+ # Delete cache entries until the size of the cache is `max_size`.
961+ while len (self ._entries ) > self ._max_size :
962+ old_query , old_entry = self ._entries .popitem (last = False )
963+ self ._clear_entry_callback (old_entry )
964+
965+ # Let the connection know that the statement was removed
966+ # from the cache.
967+ self ._on_remove (old_entry ._statement )
968+
969+
806970class _Atomic :
807971 __slots__ = ('_acquired' ,)
808972
0 commit comments