Skip to content

Commit 882cbeb

Browse files
author
Mikhail Koviazin
committed
Improve typing hints in valkey/connection.py
Signed-off-by: Mikhail Koviazin <mikhail.koviazin@aiven.io>
1 parent e9104ed commit 882cbeb

File tree

3 files changed

+31
-33
lines changed

3 files changed

+31
-33
lines changed

valkey/_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from valkey.typing import KeyT, ResponseT
1010

1111

12-
class EvictionPolicy(Enum):
12+
class EvictionPolicy(str, Enum):
1313
LRU = "lru"
1414
LFU = "lfu"
1515
RANDOM = "random"

valkey/connection.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from os import getpid
2424
from queue import Empty, Full, LifoQueue
2525
from time import time
26-
from typing import Any, Callable, List, Optional, Sequence, Type, Union
26+
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union
2727

2828
from ._cache import (
2929
DEFAULT_ALLOW_LIST,
@@ -46,7 +46,7 @@
4646
ValkeyError,
4747
)
4848
from .retry import Retry
49-
from .typing import KeysT, ResponseT
49+
from .typing import KeyT, ResponseT
5050
from .utils import (
5151
CRYPTOGRAPHY_AVAILABLE,
5252
LIBVALKEY_AVAILABLE,
@@ -66,8 +66,6 @@
6666

6767
DEFAULT_RESP_VERSION = 2
6868

69-
SENTINEL = object()
70-
7169
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _LibvalkeyParser]]
7270
if LIBVALKEY_AVAILABLE:
7371
DefaultParser = _LibvalkeyParser
@@ -76,9 +74,9 @@
7674

7775

7876
class LibvalkeyRespSerializer:
79-
def pack(self, *args: List):
77+
def pack(self, *args) -> List[bytes]:
8078
"""Pack a series of arguments into the Valkey protocol"""
81-
output = []
79+
output: List[bytes] = []
8280

8381
if isinstance(args[0], str):
8482
args = tuple(args[0].encode().split()) + args[1:]
@@ -98,7 +96,7 @@ def __init__(self, buffer_cutoff, encode) -> None:
9896
self._buffer_cutoff = buffer_cutoff
9997
self.encode = encode
10098

101-
def pack(self, *args):
99+
def pack(self, *args) -> List[bytes]:
102100
"""Pack a series of arguments into the Valkey protocol"""
103101
output = []
104102
# the client might have included 1 or more literal arguments in
@@ -154,7 +152,7 @@ def __init__(
154152
socket_timeout: Optional[float] = 5,
155153
socket_connect_timeout: Optional[float] = None,
156154
retry_on_timeout: bool = False,
157-
retry_on_error=SENTINEL,
155+
retry_on_error: Optional[List[Type[Exception]]] = None,
158156
encoding: str = "utf-8",
159157
encoding_errors: str = "strict",
160158
decode_responses: bool = False,
@@ -166,15 +164,15 @@ def __init__(
166164
lib_version: Optional[str] = get_lib_version(),
167165
username: Optional[str] = None,
168166
retry: Union[Any, None] = None,
169-
valkey_connect_func: Optional[Callable[[], None]] = None,
167+
valkey_connect_func: Optional[Callable[["AbstractConnection"], None]] = None,
170168
credential_provider: Optional[CredentialProvider] = None,
171-
protocol: Optional[int] = 2,
169+
protocol: Optional[Union[int, str]] = 2,
172170
command_packer: Optional[Callable[[], None]] = None,
173171
cache_enabled: bool = False,
174172
client_cache: Optional[AbstractCache] = None,
175173
cache_max_size: int = 10000,
176174
cache_ttl: int = 0,
177-
cache_policy: str = DEFAULT_EVICTION_POLICY,
175+
cache_policy=DEFAULT_EVICTION_POLICY,
178176
cache_deny_list: List[str] = DEFAULT_DENY_LIST,
179177
cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
180178
):
@@ -205,7 +203,7 @@ def __init__(
205203
socket_connect_timeout = socket_timeout
206204
self.socket_connect_timeout = socket_connect_timeout
207205
self.retry_on_timeout = retry_on_timeout
208-
if retry_on_error is SENTINEL:
206+
if retry_on_error is None:
209207
retry_on_error = []
210208
if retry_on_timeout:
211209
# Add TimeoutError to the errors list to retry on
@@ -222,18 +220,16 @@ def __init__(
222220
else:
223221
self.retry = Retry(NoBackoff(), 0)
224222
self.health_check_interval = health_check_interval
225-
self.next_health_check = 0
223+
self.next_health_check = 0.0
226224
self.valkey_connect_func = valkey_connect_func
227225
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
228-
self._sock = None
226+
self._sock: Optional[socket.socket] = None
229227
self._socket_read_size = socket_read_size
230228
self.set_parser(parser_class)
231-
self._connect_callbacks = []
229+
self._connect_callbacks: List[weakref.WeakMethod] = []
232230
self._buffer_cutoff = 6000
233231
try:
234-
p = int(protocol)
235-
except TypeError:
236-
p = DEFAULT_RESP_VERSION
232+
p = int(protocol) if protocol is not None else DEFAULT_RESP_VERSION
237233
except ValueError:
238234
raise ConnectionError("protocol must be an integer")
239235
finally:
@@ -260,7 +256,7 @@ def __repr__(self):
260256
return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
261257

262258
@abstractmethod
263-
def repr_pieces(self):
259+
def repr_pieces(self) -> List[Tuple[str, Any]]:
264260
pass
265261

266262
def __del__(self):
@@ -310,7 +306,7 @@ def set_parser(self, parser_class):
310306

311307
def connect(self):
312308
"Connects to the Valkey server if not already connected"
313-
if self._sock:
309+
if self._sock is not None:
314310
return
315311
try:
316312
sock = self.retry.call_with_retry(
@@ -348,7 +344,7 @@ def _connect(self):
348344
pass
349345

350346
@abstractmethod
351-
def _host_error(self):
347+
def _host_error(self) -> str:
352348
pass
353349

354350
def _error_message(self, exception):
@@ -377,7 +373,7 @@ def on_connect(self):
377373
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
378374
self._parser.on_connect(self)
379375
if len(auth_args) == 1:
380-
auth_args = ["default", auth_args[0]]
376+
auth_args = ("default", auth_args[0])
381377
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
382378
response = self.read_response()
383379
# if response.get(b"proto") != self.protocol and response.get(
@@ -492,6 +488,7 @@ def send_packed_command(self, command, check_health=True):
492488
"""Send an already packed command to the Valkey server"""
493489
if not self._sock:
494490
self.connect()
491+
assert self._sock is not None
495492
# guard against health check recursion
496493
if check_health:
497494
self.check_health()
@@ -592,8 +589,8 @@ def pack_command(self, *args):
592589

593590
def pack_commands(self, commands):
594591
"""Pack multiple commands into the Valkey protocol"""
595-
output = []
596-
pieces = []
592+
output: List[bytes] = []
593+
pieces: List[bytes] = []
597594
buffer_length = 0
598595
buffer_cutoff = self._buffer_cutoff
599596

@@ -629,6 +626,7 @@ def _cache_invalidation_process(
629626
and the second string is the list of keys to invalidate.
630627
(if the list of keys is None, then all keys are invalidated)
631628
"""
629+
assert self.client_cache is not None
632630
if data[1] is None:
633631
self.client_cache.flush()
634632
else:
@@ -650,7 +648,7 @@ def _get_from_local_cache(self, command: Sequence[str]):
650648
return self.client_cache.get(command)
651649

652650
def _add_to_local_cache(
653-
self, command: Sequence[str], response: ResponseT, keys: List[KeysT]
651+
self, command: Sequence[str], response: ResponseT, keys: List[KeyT]
654652
):
655653
"""
656654
Add the command and response to the local cache if the command
@@ -671,7 +669,7 @@ def delete_command_from_cache(self, command: Union[str, Sequence[str]]):
671669
if self.client_cache:
672670
self.client_cache.delete_command(command)
673671

674-
def invalidate_key_from_cache(self, key: KeysT):
672+
def invalidate_key_from_cache(self, key: KeyT):
675673
if self.client_cache:
676674
self.client_cache.invalidate_key(key)
677675

@@ -1036,7 +1034,7 @@ def __init__(
10361034
self._fork_lock = threading.Lock()
10371035
self.reset()
10381036

1039-
def __repr__(self) -> (str, str):
1037+
def __repr__(self) -> str:
10401038
return (
10411039
f"<{type(self).__module__}.{type(self).__name__}"
10421040
f"({repr(self.connection_class(**self.connection_kwargs))})>"
@@ -1045,8 +1043,8 @@ def __repr__(self) -> (str, str):
10451043
def reset(self) -> None:
10461044
self._lock = threading.Lock()
10471045
self._created_connections = 0
1048-
self._available_connections = []
1049-
self._in_use_connections = set()
1046+
self._available_connections: list[Connection] = []
1047+
self._in_use_connections: set[Connection] = set()
10501048

10511049
# this must be the last operation in this method. while reset() is
10521050
# called when holding _fork_lock, other threads in this process
@@ -1193,7 +1191,7 @@ def disconnect(self, inuse_connections: bool = True) -> None:
11931191
self._checkpid()
11941192
with self._lock:
11951193
if inuse_connections:
1196-
connections = chain(
1194+
connections: Iterable[Connection] = chain(
11971195
self._available_connections, self._in_use_connections
11981196
)
11991197
else:
@@ -1388,7 +1386,7 @@ def release(self, connection):
13881386
# we don't want this connection
13891387
pass
13901388

1391-
def disconnect(self):
1389+
def disconnect(self, _: bool = True) -> None:
13921390
"Disconnects all connections in the pool."
13931391
self._checkpid()
13941392
for connection in self._connections:

valkey/retry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
T = TypeVar("T")
88

99
if TYPE_CHECKING:
10-
from redis.backoff import AbstractBackoff
10+
from valkey.backoff import AbstractBackoff
1111

1212

1313
class Retry:

0 commit comments

Comments
 (0)