@@ -43,11 +43,13 @@ class Connection(metaclass=ConnectionMeta):
4343 '_addr' , '_opts' , '_command_timeout' , '_listeners' ,
4444 '_server_version' , '_server_caps' , '_intro_query' ,
4545 '_reset_query' , '_proxy' , '_stmt_exclusive_section' ,
46- '_ssl_context' )
46+ '_max_cacheable_statement_size' , ' _ssl_context' )
4747
4848 def __init__ (self , protocol , transport , loop , addr , opts , * ,
4949 statement_cache_size , command_timeout ,
50- max_cached_statement_lifetime , ssl_context ):
50+ max_cached_statement_lifetime ,
51+ max_cacheable_statement_size ,
52+ ssl_context ):
5153 self ._protocol = protocol
5254 self ._transport = transport
5355 self ._loop = loop
@@ -61,6 +63,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6163 self ._opts = opts
6264 self ._ssl_context = ssl_context
6365
66+ self ._max_cacheable_statement_size = max_cacheable_statement_size
6467 self ._stmt_cache = _StatementCache (
6568 loop = loop ,
6669 max_size = statement_cache_size ,
@@ -69,22 +72,6 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6972
7073 self ._stmts_to_close = set ()
7174
72- if command_timeout is not None :
73- try :
74- if isinstance (command_timeout , bool ):
75- raise ValueError
76-
77- command_timeout = float (command_timeout )
78-
79- if command_timeout < 0 :
80- raise ValueError
81-
82- except ValueError :
83- raise ValueError (
84- 'invalid command_timeout value: '
85- 'expected non-negative float (got {!r})' .format (
86- command_timeout )) from None
87-
8875 self ._command_timeout = command_timeout
8976
9077 self ._listeners = {}
@@ -280,7 +267,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
280267 if statement is not None :
281268 return statement
282269
283- if self ._stmt_cache .get_max_size () or named :
270+ # Only use the cache when:
271+ # * `statement_cache_size` is greater than 0;
272+ # * query size is less than `max_cacheable_statement_size`.
273+ use_cache = self ._stmt_cache .get_max_size () > 0
274+ if (use_cache and
275+ self ._max_cacheable_statement_size and
276+ len (query ) > self ._max_cacheable_statement_size ):
277+ use_cache = False
278+
279+ if use_cache or named :
284280 stmt_name = self ._get_unique_id ('stmt' )
285281 else :
286282 stmt_name = ''
@@ -295,7 +291,8 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
295291 types = await self ._types_stmt .fetch (list (ready ))
296292 self ._protocol .get_settings ().register_data_types (types )
297293
298- self ._stmt_cache .put (query , statement )
294+ if use_cache :
295+ self ._stmt_cache .put (query , statement )
299296
300297 # If we've just created a new statement object, check if there
301298 # are any statements for GC.
@@ -721,6 +718,7 @@ async def connect(dsn=None, *,
721718 timeout = 60 ,
722719 statement_cache_size = 100 ,
723720 max_cached_statement_lifetime = 300 ,
721+ max_cacheable_statement_size = 1024 * 15 ,
724722 command_timeout = None ,
725723 ssl = None ,
726724 __connection_class__ = Connection ,
@@ -772,6 +770,11 @@ async def connect(dsn=None, *,
772770 in the cache. Pass ``0`` to allow statements be cached
773771 indefinitely.
774772
773+ :param int max_cacheable_statement_size:
774+ the maximum size of a statement that can be cached (15KiB by
775+ default). Pass ``0`` to allow all statements to be cached
776+ regardless of their size.
777+
775778 :param float command_timeout:
776779 the default timeout for operations on this connection
777780 (the default is no timeout).
@@ -807,6 +810,29 @@ async def connect(dsn=None, *,
807810 if loop is None :
808811 loop = asyncio .get_event_loop ()
809812
813+ local_vars = locals ()
814+ for var_name in {'max_cacheable_statement_size' ,
815+ 'max_cached_statement_lifetime' ,
816+ 'statement_cache_size' }:
817+ var_val = local_vars [var_name ]
818+ if var_val is None or isinstance (var_val , bool ) or var_val < 0 :
819+ raise ValueError (
820+ '{} is expected to be greater '
821+ 'or equal to 0, got {!r}' .format (var_name , var_val ))
822+
823+ if command_timeout is not None :
824+ try :
825+ if isinstance (command_timeout , bool ):
826+ raise ValueError
827+ command_timeout = float (command_timeout )
828+ if command_timeout < 0 :
829+ raise ValueError
830+ except ValueError :
831+ raise ValueError (
832+ 'invalid command_timeout value: '
833+ 'expected non-negative float (got {!r})' .format (
834+ command_timeout )) from None
835+
810836 addrs , opts = _parse_connect_params (
811837 dsn = dsn , host = host , port = port , user = user , password = password ,
812838 database = database , opts = opts )
@@ -855,6 +881,7 @@ async def connect(dsn=None, *,
855881 pr , tr , loop , addr , opts ,
856882 statement_cache_size = statement_cache_size ,
857883 max_cached_statement_lifetime = max_cached_statement_lifetime ,
884+ max_cacheable_statement_size = max_cacheable_statement_size ,
858885 command_timeout = command_timeout , ssl_context = ssl )
859886
860887 pr .set_connection (con )
0 commit comments