Skip to content

Commit 495bf7e

Browse files
authored
Merge pull request #200 from valkey-io/mkmkme/fix-typing-connection
Improve typing hints in valkey/connection.py
2 parents e9104ed + 8188e5f commit 495bf7e

File tree

3 files changed

+36
-37
lines changed

3 files changed

+36
-37
lines changed

valkey/_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from valkey.typing import KeyT, ResponseT
1010

1111

12-
class EvictionPolicy(Enum):
12+
class EvictionPolicy(str, Enum):
1313
LRU = "lru"
1414
LFU = "lfu"
1515
RANDOM = "random"

valkey/connection.py

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import threading
66
import weakref
77
from abc import abstractmethod
8+
from collections.abc import Callable, Iterable, Sequence
89
from itertools import chain
910

1011
# We need to explicitly import `getpid` from `os` instead of importing `os`. The
@@ -23,7 +24,7 @@
2324
from os import getpid
2425
from queue import Empty, Full, LifoQueue
2526
from time import time
26-
from typing import Any, Callable, List, Optional, Sequence, Type, Union
27+
from typing import Any, Optional, Union
2728

2829
from ._cache import (
2930
DEFAULT_ALLOW_LIST,
@@ -46,7 +47,7 @@
4647
ValkeyError,
4748
)
4849
from .retry import Retry
49-
from .typing import KeysT, ResponseT
50+
from .typing import KeyT, ResponseT
5051
from .utils import (
5152
CRYPTOGRAPHY_AVAILABLE,
5253
LIBVALKEY_AVAILABLE,
@@ -66,19 +67,17 @@
6667

6768
DEFAULT_RESP_VERSION = 2
6869

69-
SENTINEL = object()
70-
71-
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _LibvalkeyParser]]
70+
DefaultParser: type[Union[_RESP2Parser, _RESP3Parser, _LibvalkeyParser]]
7271
if LIBVALKEY_AVAILABLE:
7372
DefaultParser = _LibvalkeyParser
7473
else:
7574
DefaultParser = _RESP2Parser
7675

7776

7877
class LibvalkeyRespSerializer:
79-
def pack(self, *args: List):
78+
def pack(self, *args) -> list[bytes]:
8079
"""Pack a series of arguments into the Valkey protocol"""
81-
output = []
80+
output: list[bytes] = []
8281

8382
if isinstance(args[0], str):
8483
args = tuple(args[0].encode().split()) + args[1:]
@@ -98,7 +97,7 @@ def __init__(self, buffer_cutoff, encode) -> None:
9897
self._buffer_cutoff = buffer_cutoff
9998
self.encode = encode
10099

101-
def pack(self, *args):
100+
def pack(self, *args) -> list[bytes]:
102101
"""Pack a series of arguments into the Valkey protocol"""
103102
output = []
104103
# the client might have included 1 or more literal arguments in
@@ -154,7 +153,7 @@ def __init__(
154153
socket_timeout: Optional[float] = 5,
155154
socket_connect_timeout: Optional[float] = None,
156155
retry_on_timeout: bool = False,
157-
retry_on_error=SENTINEL,
156+
retry_on_error: Optional[list[type[Exception]]] = None,
158157
encoding: str = "utf-8",
159158
encoding_errors: str = "strict",
160159
decode_responses: bool = False,
@@ -166,17 +165,17 @@ def __init__(
166165
lib_version: Optional[str] = get_lib_version(),
167166
username: Optional[str] = None,
168167
retry: Union[Any, None] = None,
169-
valkey_connect_func: Optional[Callable[[], None]] = None,
168+
valkey_connect_func: Optional[Callable[["AbstractConnection"], None]] = None,
170169
credential_provider: Optional[CredentialProvider] = None,
171-
protocol: Optional[int] = 2,
170+
protocol: Optional[Union[int, str]] = 2,
172171
command_packer: Optional[Callable[[], None]] = None,
173172
cache_enabled: bool = False,
174173
client_cache: Optional[AbstractCache] = None,
175174
cache_max_size: int = 10000,
176175
cache_ttl: int = 0,
177-
cache_policy: str = DEFAULT_EVICTION_POLICY,
178-
cache_deny_list: List[str] = DEFAULT_DENY_LIST,
179-
cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
176+
cache_policy=DEFAULT_EVICTION_POLICY,
177+
cache_deny_list: list[str] = DEFAULT_DENY_LIST,
178+
cache_allow_list: list[str] = DEFAULT_ALLOW_LIST,
180179
):
181180
"""
182181
Initialize a new Connection.
@@ -205,7 +204,7 @@ def __init__(
205204
socket_connect_timeout = socket_timeout
206205
self.socket_connect_timeout = socket_connect_timeout
207206
self.retry_on_timeout = retry_on_timeout
208-
if retry_on_error is SENTINEL:
207+
if retry_on_error is None:
209208
retry_on_error = []
210209
if retry_on_timeout:
211210
# Add TimeoutError to the errors list to retry on
@@ -222,18 +221,16 @@ def __init__(
222221
else:
223222
self.retry = Retry(NoBackoff(), 0)
224223
self.health_check_interval = health_check_interval
225-
self.next_health_check = 0
224+
self.next_health_check = 0.0
226225
self.valkey_connect_func = valkey_connect_func
227226
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
228-
self._sock = None
227+
self._sock: Optional[socket.socket] = None
229228
self._socket_read_size = socket_read_size
230229
self.set_parser(parser_class)
231-
self._connect_callbacks = []
230+
self._connect_callbacks: list[weakref.WeakMethod] = []
232231
self._buffer_cutoff = 6000
233232
try:
234-
p = int(protocol)
235-
except TypeError:
236-
p = DEFAULT_RESP_VERSION
233+
p = int(protocol) if protocol is not None else DEFAULT_RESP_VERSION
237234
except ValueError:
238235
raise ConnectionError("protocol must be an integer")
239236
finally:
@@ -260,7 +257,7 @@ def __repr__(self):
260257
return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
261258

262259
@abstractmethod
263-
def repr_pieces(self):
260+
def repr_pieces(self) -> list[tuple[str, Any]]:
264261
pass
265262

266263
def __del__(self):
@@ -310,7 +307,7 @@ def set_parser(self, parser_class):
310307

311308
def connect(self):
312309
"Connects to the Valkey server if not already connected"
313-
if self._sock:
310+
if self._sock is not None:
314311
return
315312
try:
316313
sock = self.retry.call_with_retry(
@@ -348,7 +345,7 @@ def _connect(self):
348345
pass
349346

350347
@abstractmethod
351-
def _host_error(self):
348+
def _host_error(self) -> str:
352349
pass
353350

354351
def _error_message(self, exception):
@@ -377,7 +374,7 @@ def on_connect(self):
377374
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
378375
self._parser.on_connect(self)
379376
if len(auth_args) == 1:
380-
auth_args = ["default", auth_args[0]]
377+
auth_args = ("default", auth_args[0])
381378
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
382379
response = self.read_response()
383380
# if response.get(b"proto") != self.protocol and response.get(
@@ -492,6 +489,7 @@ def send_packed_command(self, command, check_health=True):
492489
"""Send an already packed command to the Valkey server"""
493490
if not self._sock:
494491
self.connect()
492+
assert self._sock is not None
495493
# guard against health check recursion
496494
if check_health:
497495
self.check_health()
@@ -592,8 +590,8 @@ def pack_command(self, *args):
592590

593591
def pack_commands(self, commands):
594592
"""Pack multiple commands into the Valkey protocol"""
595-
output = []
596-
pieces = []
593+
output: list[bytes] = []
594+
pieces: list[bytes] = []
597595
buffer_length = 0
598596
buffer_cutoff = self._buffer_cutoff
599597

@@ -621,14 +619,15 @@ def pack_commands(self, commands):
621619
return output
622620

623621
def _cache_invalidation_process(
624-
self, data: List[Union[str, Optional[List[str]]]]
622+
self, data: list[Union[str, Optional[list[str]]]]
625623
) -> None:
626624
"""
627625
Invalidate (delete) all valkey commands associated with a specific key.
628626
`data` is a list of strings, where the first string is the invalidation message
629627
and the second string is the list of keys to invalidate.
630628
(if the list of keys is None, then all keys are invalidated)
631629
"""
630+
assert self.client_cache is not None
632631
if data[1] is None:
633632
self.client_cache.flush()
634633
else:
@@ -650,7 +649,7 @@ def _get_from_local_cache(self, command: Sequence[str]):
650649
return self.client_cache.get(command)
651650

652651
def _add_to_local_cache(
653-
self, command: Sequence[str], response: ResponseT, keys: List[KeysT]
652+
self, command: Sequence[str], response: ResponseT, keys: list[KeyT]
654653
):
655654
"""
656655
Add the command and response to the local cache if the command
@@ -671,7 +670,7 @@ def delete_command_from_cache(self, command: Union[str, Sequence[str]]):
671670
if self.client_cache:
672671
self.client_cache.delete_command(command)
673672

674-
def invalidate_key_from_cache(self, key: KeysT):
673+
def invalidate_key_from_cache(self, key: KeyT):
675674
if self.client_cache:
676675
self.client_cache.invalidate_key(key)
677676

@@ -1036,7 +1035,7 @@ def __init__(
10361035
self._fork_lock = threading.Lock()
10371036
self.reset()
10381037

1039-
def __repr__(self) -> (str, str):
1038+
def __repr__(self) -> str:
10401039
return (
10411040
f"<{type(self).__module__}.{type(self).__name__}"
10421041
f"({repr(self.connection_class(**self.connection_kwargs))})>"
@@ -1045,8 +1044,8 @@ def __repr__(self) -> (str, str):
10451044
def reset(self) -> None:
10461045
self._lock = threading.Lock()
10471046
self._created_connections = 0
1048-
self._available_connections = []
1049-
self._in_use_connections = set()
1047+
self._available_connections: list[Connection] = []
1048+
self._in_use_connections: set[Connection] = set()
10501049

10511050
# this must be the last operation in this method. while reset() is
10521051
# called when holding _fork_lock, other threads in this process
@@ -1193,7 +1192,7 @@ def disconnect(self, inuse_connections: bool = True) -> None:
11931192
self._checkpid()
11941193
with self._lock:
11951194
if inuse_connections:
1196-
connections = chain(
1195+
connections: Iterable[Connection] = chain(
11971196
self._available_connections, self._in_use_connections
11981197
)
11991198
else:
@@ -1388,7 +1387,7 @@ def release(self, connection):
13881387
# we don't want this connection
13891388
pass
13901389

1391-
def disconnect(self):
1390+
def disconnect(self, _: bool = True) -> None:
13921391
"Disconnects all connections in the pool."
13931392
self._checkpid()
13941393
for connection in self._connections:

valkey/retry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
T = TypeVar("T")
88

99
if TYPE_CHECKING:
10-
from redis.backoff import AbstractBackoff
10+
from valkey.backoff import AbstractBackoff
1111

1212

1313
class Retry:

0 commit comments

Comments
 (0)