11import abc
22import socket
33from time import sleep
4- from typing import TYPE_CHECKING , Any , Callable , Iterable , Tuple , Type , TypeVar , Union
4+ from typing import TYPE_CHECKING , Any , Callable , Generic , Iterable , Tuple , Type , TypeVar , Union
55
66from redis .exceptions import ConnectionError , TimeoutError
77
88T = TypeVar ("T" )
9+ E = TypeVar ("E" , bound = Exception , covariant = True )
910
1011if TYPE_CHECKING :
1112 from redis .backoff import AbstractBackoff
1213
1314
14- class AbstractRetry (abc .ABC ):
15+ class AbstractRetry (Generic [ E ], abc .ABC ):
1516 """Retry a specific number of times after a failure"""
1617
17- _supported_errors : Tuple [Type [Exception ], ...]
18+ _supported_errors : Tuple [Type [E ], ...]
1819
1920 def __init__ (
2021 self ,
2122 backoff : "AbstractBackoff" ,
2223 retries : int ,
23- supported_errors : Union [ Tuple [Type [Exception ], ...], None ] = None ,
24+ supported_errors : Tuple [Type [E ], ...],
2425 ):
2526 """
2627 Initialize a `Retry` object with a `Backoff` object
@@ -31,8 +32,7 @@ def __init__(
3132 """
3233 self ._backoff = backoff
3334 self ._retries = retries
34- if supported_errors :
35- self ._supported_errors = supported_errors
35+ self ._supported_errors = supported_errors
3636
3737 @abc .abstractmethod
3838 def __eq__ (self , other : Any ) -> bool :
@@ -42,7 +42,7 @@ def __hash__(self) -> int:
4242 return hash ((self ._backoff , self ._retries , frozenset (self ._supported_errors )))
4343
4444 def update_supported_errors (
45- self , specified_errors : Iterable [Type [Exception ]]
45+ self , specified_errors : Iterable [Type [E ]]
4646 ) -> None :
4747 """
4848 Updates the supported errors with the specified error types
@@ -64,14 +64,21 @@ def update_retries(self, value: int) -> None:
6464 self ._retries = value
6565
6666
67- class Retry (AbstractRetry ):
68- _supported_errors : Tuple [Type [Exception ], ...] = (
69- ConnectionError ,
70- TimeoutError ,
71- socket .timeout ,
72- )
67+ class Retry (AbstractRetry [Exception ]):
7368 __hash__ = AbstractRetry .__hash__
7469
70+ def __init__ (
71+ self ,
72+ backoff : "AbstractBackoff" ,
73+ retries : int ,
74+ supported_errors : Tuple [Type [Exception ], ...] = (
75+ ConnectionError , TimeoutError , socket .timeout
76+ ),
77+ ):
78+ super ().__init__ (backoff , retries , supported_errors )
79+
80+ __init__ .__doc__ = AbstractRetry .__init__ .__doc__
81+
7582 def __eq__ (self , other : Any ) -> bool :
7683 if not isinstance (other , Retry ):
7784 return NotImplemented
0 commit comments