44import time
55import warnings
66from itertools import chain
7- from typing import Optional , Type
7+ from typing import Any , Callable , Dict , List , Optional , Type , Union
88
9+ from redis ._parsers .encoders import Encoder
910from redis ._parsers .helpers import (
1011 _RedisCallbacks ,
1112 _RedisCallbacksRESP2 ,
4950class CaseInsensitiveDict (dict ):
5051 "Case insensitive dict implementation. Assumes string keys only."
5152
52- def __init__ (self , data ) :
53+ def __init__ (self , data : Dict [ str , str ]) -> None :
5354 for k , v in data .items ():
5455 self [k .upper ()] = v
5556
@@ -93,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
9394 """
9495
9596 @classmethod
96- def from_url (cls , url , ** kwargs ):
97+ def from_url (cls , url : str , ** kwargs ) -> None :
9798 """
9899 Return a Redis client object configured from the given URL
99100
@@ -202,7 +203,7 @@ def __init__(
202203 redis_connect_func = None ,
203204 credential_provider : Optional [CredentialProvider ] = None ,
204205 protocol : Optional [int ] = 2 ,
205- ):
206+ ) -> None :
206207 """
207208 Initialize a new Redis client.
208209 To specify a retry policy for specific errors, first set
@@ -309,14 +310,14 @@ def __init__(
309310 else :
310311 self .response_callbacks .update (_RedisCallbacksRESP2 )
311312
312- def __repr__ (self ):
313+ def __repr__ (self ) -> str :
313314 return f"{ type (self ).__name__ } <{ repr (self .connection_pool )} >"
314315
315- def get_encoder (self ):
316+ def get_encoder (self ) -> "Encoder" :
316317 """Get the connection pool's encoder"""
317318 return self .connection_pool .get_encoder ()
318319
319- def get_connection_kwargs (self ):
320+ def get_connection_kwargs (self ) -> Dict :
320321 """Get the connection's key-word arguments"""
321322 return self .connection_pool .connection_kwargs
322323
@@ -327,11 +328,11 @@ def set_retry(self, retry: "Retry") -> None:
327328 self .get_connection_kwargs ().update ({"retry" : retry })
328329 self .connection_pool .set_retry (retry )
329330
330- def set_response_callback (self , command , callback ) :
331+ def set_response_callback (self , command : str , callback : Callable ) -> None :
331332 """Set a custom Response Callback"""
332333 self .response_callbacks [command ] = callback
333334
334- def load_external_module (self , funcname , func ):
335+ def load_external_module (self , funcname , func ) -> None :
335336 """
336337 This function can be used to add externally defined redis modules,
337338 and their namespaces to the redis client.
@@ -354,7 +355,7 @@ def load_external_module(self, funcname, func):
354355 """
355356 setattr (self , funcname , func )
356357
357- def pipeline (self , transaction = True , shard_hint = None ):
358+ def pipeline (self , transaction = True , shard_hint = None ) -> "Pipeline" :
358359 """
359360 Return a new pipeline object that can queue multiple commands for
360361 later execution. ``transaction`` indicates whether all commands
@@ -366,7 +367,9 @@ def pipeline(self, transaction=True, shard_hint=None):
366367 self .connection_pool , self .response_callbacks , transaction , shard_hint
367368 )
368369
369- def transaction (self , func , * watches , ** kwargs ):
370+ def transaction (
371+ self , func : Callable [["Pipeline" ], None ], * watches , ** kwargs
372+ ) -> None :
370373 """
371374 Convenience method for executing the callable `func` as a transaction
372375 while watching all keys specified in `watches`. The 'func' callable
@@ -390,13 +393,13 @@ def transaction(self, func, *watches, **kwargs):
390393
391394 def lock (
392395 self ,
393- name ,
394- timeout = None ,
395- sleep = 0.1 ,
396- blocking = True ,
397- blocking_timeout = None ,
398- lock_class = None ,
399- thread_local = True ,
396+ name : str ,
397+ timeout : Optional [ float ] = None ,
398+ sleep : float = 0.1 ,
399+ blocking : bool = True ,
400+ blocking_timeout : Optional [ float ] = None ,
401+ lock_class : Union [ None , Any ] = None ,
402+ thread_local : bool = True ,
400403 ):
401404 """
402405 Return a new Lock object using key ``name`` that mimics
@@ -648,9 +651,9 @@ def __init__(
648651 self ,
649652 connection_pool ,
650653 shard_hint = None ,
651- ignore_subscribe_messages = False ,
652- encoder = None ,
653- push_handler_func = None ,
654+ ignore_subscribe_messages : bool = False ,
655+ encoder : Optional [ "Encoder" ] = None ,
656+ push_handler_func : Union [ None , Callable [[ str ], None ]] = None ,
654657 ):
655658 self .connection_pool = connection_pool
656659 self .shard_hint = shard_hint
@@ -672,13 +675,13 @@ def __init__(
672675 _set_info_logger ()
673676 self .reset ()
674677
675- def __enter__ (self ):
678+ def __enter__ (self ) -> "PubSub" :
676679 return self
677680
678- def __exit__ (self , exc_type , exc_value , traceback ):
681+ def __exit__ (self , exc_type , exc_value , traceback ) -> None :
679682 self .reset ()
680683
681- def __del__ (self ):
684+ def __del__ (self ) -> None :
682685 try :
683686 # if this object went out of scope prior to shutting down
684687 # subscriptions, close the connection manually before
@@ -687,7 +690,7 @@ def __del__(self):
687690 except Exception :
688691 pass
689692
690- def reset (self ):
693+ def reset (self ) -> None :
691694 if self .connection :
692695 self .connection .disconnect ()
693696 self .connection ._deregister_connect_callback (self .on_connect )
@@ -702,10 +705,10 @@ def reset(self):
702705 self .pending_unsubscribe_patterns = set ()
703706 self .subscribed_event .clear ()
704707
705- def close (self ):
708+ def close (self ) -> None :
706709 self .reset ()
707710
708- def on_connect (self , connection ):
711+ def on_connect (self , connection ) -> None :
709712 "Re-subscribe to any channels and patterns previously subscribed to"
710713 # NOTE: for python3, we can't pass bytestrings as keyword arguments
711714 # so we need to decode channel/pattern names back to unicode strings
@@ -731,7 +734,7 @@ def on_connect(self, connection):
731734 self .ssubscribe (** shard_channels )
732735
733736 @property
734- def subscribed (self ):
737+ def subscribed (self ) -> bool :
735738 """Indicates if there are subscriptions to any channels or patterns"""
736739 return self .subscribed_event .is_set ()
737740
@@ -757,7 +760,7 @@ def execute_command(self, *args):
757760 self .clean_health_check_responses ()
758761 self ._execute (connection , connection .send_command , * args , ** kwargs )
759762
760- def clean_health_check_responses (self ):
763+ def clean_health_check_responses (self ) -> None :
761764 """
762765 If any health check responses are present, clean them
763766 """
@@ -775,7 +778,7 @@ def clean_health_check_responses(self):
775778 )
776779 ttl -= 1
777780
778- def _disconnect_raise_connect (self , conn , error ):
781+ def _disconnect_raise_connect (self , conn , error ) -> None :
779782 """
780783 Close the connection and raise an exception
781784 if retry_on_timeout is not set or the error
@@ -826,7 +829,7 @@ def try_read():
826829 return None
827830 return response
828831
829- def is_health_check_response (self , response ):
832+ def is_health_check_response (self , response ) -> bool :
830833 """
831834 Check if the response is a health check response.
832835 If there are no subscriptions redis responds to PING command with a
@@ -837,7 +840,7 @@ def is_health_check_response(self, response):
837840 self .health_check_response_b , # If there wasn't
838841 ]
839842
840- def check_health (self ):
843+ def check_health (self ) -> None :
841844 conn = self .connection
842845 if conn is None :
843846 raise RuntimeError (
@@ -849,7 +852,7 @@ def check_health(self):
849852 conn .send_command ("PING" , self .HEALTH_CHECK_MESSAGE , check_health = False )
850853 self .health_check_response_counter += 1
851854
852- def _normalize_keys (self , data ):
855+ def _normalize_keys (self , data ) -> Dict :
853856 """
854857 normalize channel/pattern names to be either bytes or strings
855858 based on whether responses are automatically decoded. this saves us
@@ -983,7 +986,9 @@ def listen(self):
983986 if response is not None :
984987 yield response
985988
986- def get_message (self , ignore_subscribe_messages = False , timeout = 0.0 ):
989+ def get_message (
990+ self , ignore_subscribe_messages : bool = False , timeout : float = 0.0
991+ ):
987992 """
988993 Get the next message if one is available, otherwise None.
989994
@@ -1012,7 +1017,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
10121017
10131018 get_sharded_message = get_message
10141019
1015- def ping (self , message = None ):
1020+ def ping (self , message : Union [ str , None ] = None ) -> bool :
10161021 """
10171022 Ping the Redis server
10181023 """
@@ -1093,7 +1098,12 @@ def handle_message(self, response, ignore_subscribe_messages=False):
10931098
10941099 return message
10951100
1096- def run_in_thread (self , sleep_time = 0 , daemon = False , exception_handler = None ):
1101+ def run_in_thread (
1102+ self ,
1103+ sleep_time : int = 0 ,
1104+ daemon : bool = False ,
1105+ exception_handler : Optional [Callable ] = None ,
1106+ ) -> "PubSubWorkerThread" :
10971107 for channel , handler in self .channels .items ():
10981108 if handler is None :
10991109 raise PubSubError (f"Channel: '{ channel } ' has no handler registered" )
@@ -1114,15 +1124,23 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
11141124
11151125
11161126class PubSubWorkerThread (threading .Thread ):
1117- def __init__ (self , pubsub , sleep_time , daemon = False , exception_handler = None ):
1127+ def __init__ (
1128+ self ,
1129+ pubsub ,
1130+ sleep_time : float ,
1131+ daemon : bool = False ,
1132+ exception_handler : Union [
1133+ Callable [[Exception , "PubSub" , "PubSubWorkerThread" ], None ], None
1134+ ] = None ,
1135+ ):
11181136 super ().__init__ ()
11191137 self .daemon = daemon
11201138 self .pubsub = pubsub
11211139 self .sleep_time = sleep_time
11221140 self .exception_handler = exception_handler
11231141 self ._running = threading .Event ()
11241142
1125- def run (self ):
1143+ def run (self ) -> None :
11261144 if self ._running .is_set ():
11271145 return
11281146 self ._running .set ()
@@ -1137,7 +1155,7 @@ def run(self):
11371155 self .exception_handler (e , pubsub , self )
11381156 pubsub .close ()
11391157
1140- def stop (self ):
1158+ def stop (self ) -> None :
11411159 # trip the flag so the run loop exits. the run loop will
11421160 # close the pubsub connection, which disconnects the socket
11431161 # and returns the connection to the pool.
@@ -1175,7 +1193,7 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint)
11751193 self .watching = False
11761194 self .reset ()
11771195
1178- def __enter__ (self ):
1196+ def __enter__ (self ) -> "Pipeline" :
11791197 return self
11801198
11811199 def __exit__ (self , exc_type , exc_value , traceback ):
@@ -1187,14 +1205,14 @@ def __del__(self):
11871205 except Exception :
11881206 pass
11891207
1190- def __len__ (self ):
1208+ def __len__ (self ) -> int :
11911209 return len (self .command_stack )
11921210
1193- def __bool__ (self ):
1211+ def __bool__ (self ) -> bool :
11941212 """Pipeline instances should always evaluate to True"""
11951213 return True
11961214
1197- def reset (self ):
1215+ def reset (self ) -> None :
11981216 self .command_stack = []
11991217 self .scripts = set ()
12001218 # make sure to reset the connection state in the event that we were
@@ -1217,11 +1235,11 @@ def reset(self):
12171235 self .connection_pool .release (self .connection )
12181236 self .connection = None
12191237
1220- def close (self ):
1238+ def close (self ) -> None :
12211239 """Close the pipeline"""
12221240 self .reset ()
12231241
1224- def multi (self ):
1242+ def multi (self ) -> None :
12251243 """
12261244 Start a transactional block of the pipeline after WATCH commands
12271245 are issued. End the transactional block with `execute`.
@@ -1239,7 +1257,7 @@ def execute_command(self, *args, **kwargs):
12391257 return self .immediate_execute_command (* args , ** kwargs )
12401258 return self .pipeline_execute_command (* args , ** kwargs )
12411259
1242- def _disconnect_reset_raise (self , conn , error ):
1260+ def _disconnect_reset_raise (self , conn , error ) -> None :
12431261 """
12441262 Close the connection, reset watching state and
12451263 raise an exception if we were watching,
@@ -1282,7 +1300,7 @@ def immediate_execute_command(self, *args, **options):
12821300 lambda error : self ._disconnect_reset_raise (conn , error ),
12831301 )
12841302
1285- def pipeline_execute_command (self , * args , ** options ):
1303+ def pipeline_execute_command (self , * args , ** options ) -> "Pipeline" :
12861304 """
12871305 Stage a command to be executed when execute() is next called
12881306
@@ -1297,7 +1315,7 @@ def pipeline_execute_command(self, *args, **options):
12971315 self .command_stack .append ((args , options ))
12981316 return self
12991317
1300- def _execute_transaction (self , connection , commands , raise_on_error ):
1318+ def _execute_transaction (self , connection , commands , raise_on_error ) -> List :
13011319 cmds = chain ([(("MULTI" ,), {})], commands , [(("EXEC" ,), {})])
13021320 all_cmds = connection .pack_commands (
13031321 [args for args , options in cmds if EMPTY_RESPONSE not in options ]
@@ -1415,7 +1433,7 @@ def load_scripts(self):
14151433 if not exist :
14161434 s .sha = immediate ("SCRIPT LOAD" , s .script )
14171435
1418- def _disconnect_raise_reset (self , conn , error ) :
1436+ def _disconnect_raise_reset (self , conn : Redis , error : Exception ) -> None :
14191437 """
14201438 Close the connection, raise an exception if we were watching,
14211439 and raise an exception if TimeoutError is not part of retry_on_error,
@@ -1477,6 +1495,6 @@ def watch(self, *names):
14771495 raise RedisError ("Cannot issue a WATCH after a MULTI" )
14781496 return self .execute_command ("WATCH" , * names )
14791497
1480- def unwatch (self ):
1498+ def unwatch (self ) -> bool :
14811499 """Unwatches all previously specified keys"""
14821500 return self .watching and self .execute_command ("UNWATCH" ) or True
0 commit comments