11import abc
22import socket
33from time import sleep
4- from typing import TYPE_CHECKING , Any , Callable , Iterable , Tuple , Type , TypeVar , Union
4+ from typing import TYPE_CHECKING , Any , Callable , Generic , Iterable , Tuple , Type , TypeVar
55
66from redis .exceptions import ConnectionError , TimeoutError
77
88T = TypeVar ("T" )
9+ E = TypeVar ("E" , bound = Exception , covariant = True )
910
1011if TYPE_CHECKING :
1112 from redis .backoff import AbstractBackoff
1213
1314
14- class AbstractRetry (abc .ABC ):
15+ class AbstractRetry (Generic [ E ], abc .ABC ):
1516 """Retry a specific number of times after a failure"""
1617
17- _supported_errors : Tuple [Type [Exception ], ...]
18+ _supported_errors : Tuple [Type [E ], ...]
1819
1920 def __init__ (
2021 self ,
2122 backoff : "AbstractBackoff" ,
2223 retries : int ,
23- supported_errors : Union [ Tuple [Type [Exception ], ...], None ] = None ,
24+ supported_errors : Tuple [Type [E ], ...],
2425 ):
2526 """
2627 Initialize a `Retry` object with a `Backoff` object
@@ -31,8 +32,7 @@ def __init__(
3132 """
3233 self ._backoff = backoff
3334 self ._retries = retries
34- if supported_errors :
35- self ._supported_errors = supported_errors
35+ self ._supported_errors = supported_errors
3636
3737 @abc .abstractmethod
3838 def __eq__ (self , other : Any ) -> bool :
@@ -41,9 +41,7 @@ def __eq__(self, other: Any) -> bool:
4141 def __hash__ (self ) -> int :
4242 return hash ((self ._backoff , self ._retries , frozenset (self ._supported_errors )))
4343
44- def update_supported_errors (
45- self , specified_errors : Iterable [Type [Exception ]]
46- ) -> None :
44+ def update_supported_errors (self , specified_errors : Iterable [Type [E ]]) -> None :
4745 """
4846 Updates the supported errors with the specified error types
4947 """
@@ -64,14 +62,23 @@ def update_retries(self, value: int) -> None:
6462 self ._retries = value
6563
6664
67- class Retry (AbstractRetry ):
68- _supported_errors : Tuple [Type [Exception ], ...] = (
69- ConnectionError ,
70- TimeoutError ,
71- socket .timeout ,
72- )
65+ class Retry (AbstractRetry [Exception ]):
7366 __hash__ = AbstractRetry .__hash__
7467
68+ def __init__ (
69+ self ,
70+ backoff : "AbstractBackoff" ,
71+ retries : int ,
72+ supported_errors : Tuple [Type [Exception ], ...] = (
73+ ConnectionError ,
74+ TimeoutError ,
75+ socket .timeout ,
76+ ),
77+ ):
78+ super ().__init__ (backoff , retries , supported_errors )
79+
80+ __init__ .__doc__ = AbstractRetry .__init__ .__doc__
81+
7582 def __eq__ (self , other : Any ) -> bool :
7683 if not isinstance (other , Retry ):
7784 return NotImplemented
0 commit comments