2323from os import getpid
2424from queue import Empty , Full , LifoQueue
2525from time import time
26- from typing import Any , Callable , List , Optional , Sequence , Type , Union
26+ from typing import Any , Callable , Iterable , List , Optional , Sequence , Tuple , Type , Union
2727
2828from ._cache import (
2929 DEFAULT_ALLOW_LIST ,
4646 ValkeyError ,
4747)
4848from .retry import Retry
49- from .typing import KeysT , ResponseT
49+ from .typing import KeyT , ResponseT
5050from .utils import (
5151 CRYPTOGRAPHY_AVAILABLE ,
5252 LIBVALKEY_AVAILABLE ,
6666
6767DEFAULT_RESP_VERSION = 2
6868
69- SENTINEL = object ()
70-
7169DefaultParser : Type [Union [_RESP2Parser , _RESP3Parser , _LibvalkeyParser ]]
7270if LIBVALKEY_AVAILABLE :
7371 DefaultParser = _LibvalkeyParser
7674
7775
7876class LibvalkeyRespSerializer :
79- def pack (self , * args : List ) :
77+ def pack (self , * args ) -> List [ bytes ] :
8078 """Pack a series of arguments into the Valkey protocol"""
81- output = []
79+ output : List [ bytes ] = []
8280
8381 if isinstance (args [0 ], str ):
8482 args = tuple (args [0 ].encode ().split ()) + args [1 :]
@@ -98,7 +96,7 @@ def __init__(self, buffer_cutoff, encode) -> None:
9896 self ._buffer_cutoff = buffer_cutoff
9997 self .encode = encode
10098
101- def pack (self , * args ):
99+ def pack (self , * args ) -> List [ bytes ] :
102100 """Pack a series of arguments into the Valkey protocol"""
103101 output = []
104102 # the client might have included 1 or more literal arguments in
@@ -154,7 +152,7 @@ def __init__(
154152 socket_timeout : Optional [float ] = 5 ,
155153 socket_connect_timeout : Optional [float ] = None ,
156154 retry_on_timeout : bool = False ,
157- retry_on_error = SENTINEL ,
155+ retry_on_error : Optional [ List [ Type [ Exception ]]] = None ,
158156 encoding : str = "utf-8" ,
159157 encoding_errors : str = "strict" ,
160158 decode_responses : bool = False ,
@@ -166,15 +164,15 @@ def __init__(
166164 lib_version : Optional [str ] = get_lib_version (),
167165 username : Optional [str ] = None ,
168166 retry : Union [Any , None ] = None ,
169- valkey_connect_func : Optional [Callable [[], None ]] = None ,
167+ valkey_connect_func : Optional [Callable [["AbstractConnection" ], None ]] = None ,
170168 credential_provider : Optional [CredentialProvider ] = None ,
171- protocol : Optional [int ] = 2 ,
169+ protocol : Optional [Union [ int , str ] ] = 2 ,
172170 command_packer : Optional [Callable [[], None ]] = None ,
173171 cache_enabled : bool = False ,
174172 client_cache : Optional [AbstractCache ] = None ,
175173 cache_max_size : int = 10000 ,
176174 cache_ttl : int = 0 ,
177- cache_policy : str = DEFAULT_EVICTION_POLICY ,
175+ cache_policy = DEFAULT_EVICTION_POLICY ,
178176 cache_deny_list : List [str ] = DEFAULT_DENY_LIST ,
179177 cache_allow_list : List [str ] = DEFAULT_ALLOW_LIST ,
180178 ):
@@ -205,7 +203,7 @@ def __init__(
205203 socket_connect_timeout = socket_timeout
206204 self .socket_connect_timeout = socket_connect_timeout
207205 self .retry_on_timeout = retry_on_timeout
208- if retry_on_error is SENTINEL :
206+ if retry_on_error is None :
209207 retry_on_error = []
210208 if retry_on_timeout :
211209 # Add TimeoutError to the errors list to retry on
@@ -222,18 +220,16 @@ def __init__(
222220 else :
223221 self .retry = Retry (NoBackoff (), 0 )
224222 self .health_check_interval = health_check_interval
225- self .next_health_check = 0
223+ self .next_health_check = 0.0
226224 self .valkey_connect_func = valkey_connect_func
227225 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
228- self ._sock = None
226+ self ._sock : Optional [ socket . socket ] = None
229227 self ._socket_read_size = socket_read_size
230228 self .set_parser (parser_class )
231- self ._connect_callbacks = []
229+ self ._connect_callbacks : List [ weakref . WeakMethod ] = []
232230 self ._buffer_cutoff = 6000
233231 try :
234- p = int (protocol )
235- except TypeError :
236- p = DEFAULT_RESP_VERSION
232+ p = int (protocol ) if protocol is not None else DEFAULT_RESP_VERSION
237233 except ValueError :
238234 raise ConnectionError ("protocol must be an integer" )
239235 finally :
@@ -260,7 +256,7 @@ def __repr__(self):
260256 return f"<{ self .__class__ .__module__ } .{ self .__class__ .__name__ } ({ repr_args } )>"
261257
262258 @abstractmethod
263- def repr_pieces (self ):
259+ def repr_pieces (self ) -> List [ Tuple [ str , Any ]] :
264260 pass
265261
266262 def __del__ (self ):
@@ -310,7 +306,7 @@ def set_parser(self, parser_class):
310306
311307 def connect (self ):
312308 "Connects to the Valkey server if not already connected"
313- if self ._sock :
309+ if self ._sock is not None :
314310 return
315311 try :
316312 sock = self .retry .call_with_retry (
@@ -348,7 +344,7 @@ def _connect(self):
348344 pass
349345
350346 @abstractmethod
351- def _host_error (self ):
347+ def _host_error (self ) -> str :
352348 pass
353349
354350 def _error_message (self , exception ):
@@ -377,7 +373,7 @@ def on_connect(self):
377373 self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
378374 self ._parser .on_connect (self )
379375 if len (auth_args ) == 1 :
380- auth_args = [ "default" , auth_args [0 ]]
376+ auth_args = ( "default" , auth_args [0 ])
381377 self .send_command ("HELLO" , self .protocol , "AUTH" , * auth_args )
382378 response = self .read_response ()
383379 # if response.get(b"proto") != self.protocol and response.get(
@@ -492,6 +488,7 @@ def send_packed_command(self, command, check_health=True):
492488 """Send an already packed command to the Valkey server"""
493489 if not self ._sock :
494490 self .connect ()
491+ assert self ._sock is not None
495492 # guard against health check recursion
496493 if check_health :
497494 self .check_health ()
@@ -592,8 +589,8 @@ def pack_command(self, *args):
592589
593590 def pack_commands (self , commands ):
594591 """Pack multiple commands into the Valkey protocol"""
595- output = []
596- pieces = []
592+ output : List [ bytes ] = []
593+ pieces : List [ bytes ] = []
597594 buffer_length = 0
598595 buffer_cutoff = self ._buffer_cutoff
599596
@@ -629,6 +626,7 @@ def _cache_invalidation_process(
629626 and the second string is the list of keys to invalidate.
630627 (if the list of keys is None, then all keys are invalidated)
631628 """
629+ assert self .client_cache is not None
632630 if data [1 ] is None :
633631 self .client_cache .flush ()
634632 else :
@@ -650,7 +648,7 @@ def _get_from_local_cache(self, command: Sequence[str]):
650648 return self .client_cache .get (command )
651649
652650 def _add_to_local_cache (
653- self , command : Sequence [str ], response : ResponseT , keys : List [KeysT ]
651+ self , command : Sequence [str ], response : ResponseT , keys : List [KeyT ]
654652 ):
655653 """
656654 Add the command and response to the local cache if the command
@@ -671,7 +669,7 @@ def delete_command_from_cache(self, command: Union[str, Sequence[str]]):
671669 if self .client_cache :
672670 self .client_cache .delete_command (command )
673671
674- def invalidate_key_from_cache (self , key : KeysT ):
672+ def invalidate_key_from_cache (self , key : KeyT ):
675673 if self .client_cache :
676674 self .client_cache .invalidate_key (key )
677675
@@ -1036,7 +1034,7 @@ def __init__(
10361034 self ._fork_lock = threading .Lock ()
10371035 self .reset ()
10381036
1039- def __repr__ (self ) -> ( str , str ) :
1037+ def __repr__ (self ) -> str :
10401038 return (
10411039 f"<{ type (self ).__module__ } .{ type (self ).__name__ } "
10421040 f"({ repr (self .connection_class (** self .connection_kwargs ))} )>"
@@ -1045,8 +1043,8 @@ def __repr__(self) -> (str, str):
10451043 def reset (self ) -> None :
10461044 self ._lock = threading .Lock ()
10471045 self ._created_connections = 0
1048- self ._available_connections = []
1049- self ._in_use_connections = set ()
1046+ self ._available_connections : list [ Connection ] = []
1047+ self ._in_use_connections : set [ Connection ] = set ()
10501048
10511049 # this must be the last operation in this method. while reset() is
10521050 # called when holding _fork_lock, other threads in this process
@@ -1193,7 +1191,7 @@ def disconnect(self, inuse_connections: bool = True) -> None:
11931191 self ._checkpid ()
11941192 with self ._lock :
11951193 if inuse_connections :
1196- connections = chain (
1194+ connections : Iterable [ Connection ] = chain (
11971195 self ._available_connections , self ._in_use_connections
11981196 )
11991197 else :
@@ -1388,7 +1386,7 @@ def release(self, connection):
13881386 # we don't want this connection
13891387 pass
13901388
1391- def disconnect (self ) :
1389+ def disconnect (self , _ : bool = True ) -> None :
13921390 "Disconnects all connections in the pool."
13931391 self ._checkpid ()
13941392 for connection in self ._connections :
0 commit comments