55import threading
66import weakref
77from abc import abstractmethod
8+ from collections .abc import Callable , Iterable , Sequence
89from itertools import chain
910
1011# We need to explicitly import `getpid` from `os` instead of importing `os`. The
2324from os import getpid
2425from queue import Empty , Full , LifoQueue
2526from time import time
26- from typing import Any , Callable , List , Optional , Sequence , Type , Union
27+ from typing import Any , Optional , Union
2728
2829from ._cache import (
2930 DEFAULT_ALLOW_LIST ,
4647 ValkeyError ,
4748)
4849from .retry import Retry
49- from .typing import KeysT , ResponseT
50+ from .typing import KeyT , ResponseT
5051from .utils import (
5152 CRYPTOGRAPHY_AVAILABLE ,
5253 LIBVALKEY_AVAILABLE ,
6667
6768DEFAULT_RESP_VERSION = 2
6869
69- SENTINEL = object ()
70-
71- DefaultParser : Type [Union [_RESP2Parser , _RESP3Parser , _LibvalkeyParser ]]
70+ DefaultParser : type [Union [_RESP2Parser , _RESP3Parser , _LibvalkeyParser ]]
7271if LIBVALKEY_AVAILABLE :
7372 DefaultParser = _LibvalkeyParser
7473else :
7574 DefaultParser = _RESP2Parser
7675
7776
7877class LibvalkeyRespSerializer :
79- def pack (self , * args : List ) :
78+ def pack (self , * args ) -> list [ bytes ] :
8079 """Pack a series of arguments into the Valkey protocol"""
81- output = []
80+ output : list [ bytes ] = []
8281
8382 if isinstance (args [0 ], str ):
8483 args = tuple (args [0 ].encode ().split ()) + args [1 :]
@@ -98,7 +97,7 @@ def __init__(self, buffer_cutoff, encode) -> None:
9897 self ._buffer_cutoff = buffer_cutoff
9998 self .encode = encode
10099
101- def pack (self , * args ):
100+ def pack (self , * args ) -> list [ bytes ] :
102101 """Pack a series of arguments into the Valkey protocol"""
103102 output = []
104103 # the client might have included 1 or more literal arguments in
@@ -154,7 +153,7 @@ def __init__(
154153 socket_timeout : Optional [float ] = 5 ,
155154 socket_connect_timeout : Optional [float ] = None ,
156155 retry_on_timeout : bool = False ,
157- retry_on_error = SENTINEL ,
156+ retry_on_error : Optional [ list [ type [ Exception ]]] = None ,
158157 encoding : str = "utf-8" ,
159158 encoding_errors : str = "strict" ,
160159 decode_responses : bool = False ,
@@ -166,17 +165,17 @@ def __init__(
166165 lib_version : Optional [str ] = get_lib_version (),
167166 username : Optional [str ] = None ,
168167 retry : Union [Any , None ] = None ,
169- valkey_connect_func : Optional [Callable [[], None ]] = None ,
168+ valkey_connect_func : Optional [Callable [["AbstractConnection" ], None ]] = None ,
170169 credential_provider : Optional [CredentialProvider ] = None ,
171- protocol : Optional [int ] = 2 ,
170+ protocol : Optional [Union [ int , str ] ] = 2 ,
172171 command_packer : Optional [Callable [[], None ]] = None ,
173172 cache_enabled : bool = False ,
174173 client_cache : Optional [AbstractCache ] = None ,
175174 cache_max_size : int = 10000 ,
176175 cache_ttl : int = 0 ,
177- cache_policy : str = DEFAULT_EVICTION_POLICY ,
178- cache_deny_list : List [str ] = DEFAULT_DENY_LIST ,
179- cache_allow_list : List [str ] = DEFAULT_ALLOW_LIST ,
176+ cache_policy = DEFAULT_EVICTION_POLICY ,
177+ cache_deny_list : list [str ] = DEFAULT_DENY_LIST ,
178+ cache_allow_list : list [str ] = DEFAULT_ALLOW_LIST ,
180179 ):
181180 """
182181 Initialize a new Connection.
@@ -205,7 +204,7 @@ def __init__(
205204 socket_connect_timeout = socket_timeout
206205 self .socket_connect_timeout = socket_connect_timeout
207206 self .retry_on_timeout = retry_on_timeout
208- if retry_on_error is SENTINEL :
207+ if retry_on_error is None :
209208 retry_on_error = []
210209 if retry_on_timeout :
211210 # Add TimeoutError to the errors list to retry on
@@ -222,18 +221,16 @@ def __init__(
222221 else :
223222 self .retry = Retry (NoBackoff (), 0 )
224223 self .health_check_interval = health_check_interval
225- self .next_health_check = 0
224+ self .next_health_check = 0.0
226225 self .valkey_connect_func = valkey_connect_func
227226 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
228- self ._sock = None
227+ self ._sock : Optional [ socket . socket ] = None
229228 self ._socket_read_size = socket_read_size
230229 self .set_parser (parser_class )
231- self ._connect_callbacks = []
230+ self ._connect_callbacks : list [ weakref . WeakMethod ] = []
232231 self ._buffer_cutoff = 6000
233232 try :
234- p = int (protocol )
235- except TypeError :
236- p = DEFAULT_RESP_VERSION
233+ p = int (protocol ) if protocol is not None else DEFAULT_RESP_VERSION
237234 except ValueError :
238235 raise ConnectionError ("protocol must be an integer" )
239236 finally :
@@ -260,7 +257,7 @@ def __repr__(self):
260257 return f"<{ self .__class__ .__module__ } .{ self .__class__ .__name__ } ({ repr_args } )>"
261258
262259 @abstractmethod
263- def repr_pieces (self ):
260+ def repr_pieces (self ) -> list [ tuple [ str , Any ]] :
264261 pass
265262
266263 def __del__ (self ):
@@ -310,7 +307,7 @@ def set_parser(self, parser_class):
310307
311308 def connect (self ):
312309 "Connects to the Valkey server if not already connected"
313- if self ._sock :
310+ if self ._sock is not None :
314311 return
315312 try :
316313 sock = self .retry .call_with_retry (
@@ -348,7 +345,7 @@ def _connect(self):
348345 pass
349346
350347 @abstractmethod
351- def _host_error (self ):
348+ def _host_error (self ) -> str :
352349 pass
353350
354351 def _error_message (self , exception ):
@@ -377,7 +374,7 @@ def on_connect(self):
377374 self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
378375 self ._parser .on_connect (self )
379376 if len (auth_args ) == 1 :
380- auth_args = [ "default" , auth_args [0 ]]
377+ auth_args = ( "default" , auth_args [0 ])
381378 self .send_command ("HELLO" , self .protocol , "AUTH" , * auth_args )
382379 response = self .read_response ()
383380 # if response.get(b"proto") != self.protocol and response.get(
@@ -492,6 +489,7 @@ def send_packed_command(self, command, check_health=True):
492489 """Send an already packed command to the Valkey server"""
493490 if not self ._sock :
494491 self .connect ()
492+ assert self ._sock is not None
495493 # guard against health check recursion
496494 if check_health :
497495 self .check_health ()
@@ -592,8 +590,8 @@ def pack_command(self, *args):
592590
593591 def pack_commands (self , commands ):
594592 """Pack multiple commands into the Valkey protocol"""
595- output = []
596- pieces = []
593+ output : list [ bytes ] = []
594+ pieces : list [ bytes ] = []
597595 buffer_length = 0
598596 buffer_cutoff = self ._buffer_cutoff
599597
@@ -621,14 +619,15 @@ def pack_commands(self, commands):
621619 return output
622620
623621 def _cache_invalidation_process (
624- self , data : List [Union [str , Optional [List [str ]]]]
622+ self , data : list [Union [str , Optional [list [str ]]]]
625623 ) -> None :
626624 """
627625 Invalidate (delete) all valkey commands associated with a specific key.
628626 `data` is a list of strings, where the first string is the invalidation message
629627 and the second string is the list of keys to invalidate.
630628 (if the list of keys is None, then all keys are invalidated)
631629 """
630+ assert self .client_cache is not None
632631 if data [1 ] is None :
633632 self .client_cache .flush ()
634633 else :
@@ -650,7 +649,7 @@ def _get_from_local_cache(self, command: Sequence[str]):
650649 return self .client_cache .get (command )
651650
652651 def _add_to_local_cache (
653- self , command : Sequence [str ], response : ResponseT , keys : List [ KeysT ]
652+ self , command : Sequence [str ], response : ResponseT , keys : list [ KeyT ]
654653 ):
655654 """
656655 Add the command and response to the local cache if the command
@@ -671,7 +670,7 @@ def delete_command_from_cache(self, command: Union[str, Sequence[str]]):
671670 if self .client_cache :
672671 self .client_cache .delete_command (command )
673672
674- def invalidate_key_from_cache (self , key : KeysT ):
673+ def invalidate_key_from_cache (self , key : KeyT ):
675674 if self .client_cache :
676675 self .client_cache .invalidate_key (key )
677676
@@ -1036,7 +1035,7 @@ def __init__(
10361035 self ._fork_lock = threading .Lock ()
10371036 self .reset ()
10381037
1039- def __repr__ (self ) -> ( str , str ) :
1038+ def __repr__ (self ) -> str :
10401039 return (
10411040 f"<{ type (self ).__module__ } .{ type (self ).__name__ } "
10421041 f"({ repr (self .connection_class (** self .connection_kwargs ))} )>"
@@ -1045,8 +1044,8 @@ def __repr__(self) -> (str, str):
10451044 def reset (self ) -> None :
10461045 self ._lock = threading .Lock ()
10471046 self ._created_connections = 0
1048- self ._available_connections = []
1049- self ._in_use_connections = set ()
1047+ self ._available_connections : list [ Connection ] = []
1048+ self ._in_use_connections : set [ Connection ] = set ()
10501049
10511050 # this must be the last operation in this method. while reset() is
10521051 # called when holding _fork_lock, other threads in this process
@@ -1193,7 +1192,7 @@ def disconnect(self, inuse_connections: bool = True) -> None:
11931192 self ._checkpid ()
11941193 with self ._lock :
11951194 if inuse_connections :
1196- connections = chain (
1195+ connections : Iterable [ Connection ] = chain (
11971196 self ._available_connections , self ._in_use_connections
11981197 )
11991198 else :
@@ -1388,7 +1387,7 @@ def release(self, connection):
13881387 # we don't want this connection
13891388 pass
13901389
1391- def disconnect (self ) :
1390+ def disconnect (self , _ : bool = True ) -> None :
13921391 "Disconnects all connections in the pool."
13931392 self ._checkpid ()
13941393 for connection in self ._connections :
0 commit comments