Skip to content

Commit a6e6a25

Browse files
committed
Fix various commands to work with Redis Cluster
1 parent 3c0f9b9 commit a6e6a25

23 files changed

+1281
-326
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,4 @@ libs/redis/docs/.Trash*
222222
.idea/*
223223
.vscode/settings.json
224224
.python-version
225+
tests/data

poetry.lock

Lines changed: 19 additions & 31 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ numpy = [
2626
{ version = ">=1.26.0,<3", python = ">=3.12" },
2727
]
2828
pyyaml = ">=5.4,<7.0"
29-
redis = "^5.0"
29+
redis = "^6.0"
3030
pydantic = "^2"
3131
tenacity = ">=8.2.2"
3232
ml-dtypes = ">=0.4.0,<1.0.0"
@@ -68,8 +68,8 @@ pytest-xdist = {extras = ["psutil"], version = "^3.6.1"}
6868
pre-commit = "^4.1.0"
6969
mypy = "1.9.0"
7070
nbval = "^0.11.0"
71-
types-redis = "*"
7271
types-pyyaml = "*"
72+
types-pyopenssl = "*"
7373
testcontainers = "^4.3.1"
7474
cryptography = { version = ">=44.0.1", markers = "python_version > '3.9.1'" }
7575

redisvl/extensions/cache/base.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
specific cache types such as LLM caches and embedding caches.
55
"""
66

7-
from typing import Any, Dict, Optional
7+
from collections.abc import Mapping
8+
from typing import Any, Dict, Optional, Union
89

9-
from redis import Redis
10-
from redis.asyncio import Redis as AsyncRedis
10+
from redis import Redis # For backwards compatibility in type checking
11+
from redis.cluster import RedisCluster
1112

1213
from redisvl.redis.connection import RedisConnectionFactory
14+
from redisvl.types import AsyncRedisClient, SyncRedisClient, SyncRedisCluster
1315

1416

1517
class BaseCache:
@@ -19,14 +21,15 @@ class BaseCache:
1921
including TTL management, connection handling, and basic cache operations.
2022
"""
2123

22-
_redis_client: Optional[Redis]
23-
_async_redis_client: Optional[AsyncRedis]
24+
_redis_client: Optional[SyncRedisClient]
25+
_async_redis_client: Optional[AsyncRedisClient]
2426

2527
def __init__(
2628
self,
2729
name: str,
2830
ttl: Optional[int] = None,
29-
redis_client: Optional[Redis] = None,
31+
redis_client: Optional[SyncRedisClient] = None,
32+
async_redis_client: Optional[AsyncRedisClient] = None,
3033
redis_url: str = "redis://localhost:6379",
3134
connection_kwargs: Dict[str, Any] = {},
3235
):
@@ -36,7 +39,7 @@ def __init__(
3639
name (str): The name of the cache.
3740
ttl (Optional[int], optional): The time-to-live for records cached
3841
in Redis. Defaults to None.
39-
redis_client (Optional[Redis], optional): A redis client connection instance.
42+
redis_client (Optional[SyncRedisClient], optional): A redis client connection instance.
4043
Defaults to None.
4144
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
4245
connection_kwargs (Dict[str, Any]): The connection arguments
@@ -53,14 +56,13 @@ def __init__(
5356
}
5457

5558
# Initialize Redis clients
56-
self._async_redis_client = None
59+
self._async_redis_client = async_redis_client
60+
self._redis_client = redis_client
5761

58-
if redis_client:
62+
if redis_client or async_redis_client:
5963
self._owns_redis_client = False
60-
self._redis_client = redis_client
6164
else:
6265
self._owns_redis_client = True
63-
self._redis_client = None # type: ignore
6466

6567
def _get_prefix(self) -> str:
6668
"""Get the key prefix for Redis keys.
@@ -103,11 +105,11 @@ def set_ttl(self, ttl: Optional[int] = None) -> None:
103105
else:
104106
self._ttl = None
105107

106-
def _get_redis_client(self) -> Redis:
108+
def _get_redis_client(self) -> SyncRedisClient:
107109
"""Get or create a Redis client.
108110
109111
Returns:
110-
Redis: A Redis client instance.
112+
SyncRedisClient: A Redis client instance.
111113
"""
112114
if self._redis_client is None:
113115
# Create new Redis client
@@ -116,22 +118,29 @@ def _get_redis_client(self) -> Redis:
116118
self._redis_client = Redis.from_url(url, **kwargs) # type: ignore
117119
return self._redis_client
118120

119-
async def _get_async_redis_client(self) -> AsyncRedis:
121+
async def _get_async_redis_client(self) -> AsyncRedisClient:
120122
"""Get or create an async Redis client.
121123
122124
Returns:
123-
AsyncRedis: An async Redis client instance.
125+
AsyncRedisClient: An async Redis client instance.
124126
"""
125127
if not hasattr(self, "_async_redis_client") or self._async_redis_client is None:
126128
client = self.redis_kwargs.get("redis_client")
127-
if isinstance(client, Redis):
129+
130+
if client and isinstance(client, (Redis, RedisCluster)):
128131
self._async_redis_client = RedisConnectionFactory.sync_to_async_redis(
129132
client
130133
)
131134
else:
132-
url = self.redis_kwargs["redis_url"]
133-
kwargs = self.redis_kwargs["connection_kwargs"]
134-
self._async_redis_client = RedisConnectionFactory.get_async_redis_connection(url, **kwargs) # type: ignore
135+
url = str(self.redis_kwargs["redis_url"])
136+
kwargs = self.redis_kwargs.get("connection_kwargs", {})
137+
if not isinstance(kwargs, Mapping):
138+
raise ValueError(
139+
f"connection_kwargs must be a mapping, got {type(kwargs)}"
140+
)
141+
self._async_redis_client = (
142+
RedisConnectionFactory.get_async_redis_connection(url, **kwargs)
143+
)
135144
return self._async_redis_client
136145

137146
def expire(self, key: str, ttl: Optional[int] = None) -> None:
@@ -183,7 +192,14 @@ def clear(self) -> None:
183192
client.delete(*keys)
184193
if cursor_int == 0: # Redis returns 0 when scan is complete
185194
break
186-
cursor = cursor_int # Update cursor for next iteration
195+
# Cluster returns a dict of cursor values. We need to stop if these all
196+
# come back as 0.
197+
elif isinstance(cursor_int, Mapping):
198+
cursor_values = list(cursor_int.values())
199+
if all(v == 0 for v in cursor_values):
200+
break
201+
else:
202+
cursor = cursor_int # Update cursor for next iteration
187203

188204
async def aclear(self) -> None:
189205
"""Async clear the cache of all keys."""
@@ -193,7 +209,9 @@ async def aclear(self) -> None:
193209
# Scan for all keys with our prefix
194210
cursor = 0 # Start with cursor 0
195211
while True:
196-
cursor_int, keys = await client.scan(cursor=cursor, match=f"{prefix}*", count=100) # type: ignore
212+
cursor_int, keys = await client.scan(
213+
cursor=cursor, match=f"{prefix}*", count=100
214+
) # type: ignore
197215
if keys:
198216
await client.delete(*keys)
199217
if cursor_int == 0: # Redis returns 0 when scan is complete
@@ -207,12 +225,10 @@ def disconnect(self) -> None:
207225

208226
if self._redis_client:
209227
self._redis_client.close()
210-
self._redis_client = None # type: ignore
211-
212-
if hasattr(self, "_async_redis_client") and self._async_redis_client:
213-
# Use synchronous close for async client in synchronous context
214-
self._async_redis_client.close() # type: ignore
215-
self._async_redis_client = None # type: ignore
228+
self._redis_client = None
229+
# Async clients don't have a sync close method, so we just
230+
# zero them out to allow garbage collection.
231+
self._async_redis_client = None
216232

217233
async def adisconnect(self) -> None:
218234
"""Async disconnect from Redis."""
@@ -221,9 +237,9 @@ async def adisconnect(self) -> None:
221237

222238
if self._redis_client:
223239
self._redis_client.close()
224-
self._redis_client = None # type: ignore
240+
self._redis_client = None
225241

226242
if hasattr(self, "_async_redis_client") and self._async_redis_client:
227243
# Use proper async close method
228-
await self._async_redis_client.aclose() # type: ignore
229-
self._async_redis_client = None # type: ignore
244+
await self._async_redis_client.aclose()
245+
self._async_redis_client = None

redisvl/extensions/cache/embeddings/embeddings.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Embeddings cache implementation for RedisVL."""
22

3-
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
4-
5-
from redis import Redis
6-
from redis.asyncio import Redis as AsyncRedis
3+
from typing import Any, Awaitable, Dict, List, Optional, Tuple, cast
74

85
from redisvl.extensions.cache.base import BaseCache
96
from redisvl.extensions.cache.embeddings.schema import CacheEntry
107
from redisvl.redis.utils import convert_bytes, hashify
8+
from redisvl.types import AsyncRedisClient, SyncRedisClient
9+
from redisvl.utils.log import get_logger
10+
11+
logger = get_logger(__name__)
1112

1213

1314
class EmbeddingsCache(BaseCache):
@@ -17,7 +18,8 @@ def __init__(
1718
self,
1819
name: str = "embedcache",
1920
ttl: Optional[int] = None,
20-
redis_client: Optional[Redis] = None,
21+
redis_client: Optional[SyncRedisClient] = None,
22+
async_redis_client: Optional[AsyncRedisClient] = None,
2123
redis_url: str = "redis://localhost:6379",
2224
connection_kwargs: Dict[str, Any] = {},
2325
):
@@ -26,7 +28,7 @@ def __init__(
2628
Args:
2729
name (str): The name of the cache. Defaults to "embedcache".
2830
ttl (Optional[int]): The time-to-live for cached embeddings. Defaults to None.
29-
redis_client (Optional[Redis]): Redis client instance. Defaults to None.
31+
redis_client (Optional[SyncRedisClient]): Redis client instance. Defaults to None.
3032
redis_url (str): Redis URL for connection. Defaults to "redis://localhost:6379".
3133
connection_kwargs (Dict[str, Any]): Redis connection arguments. Defaults to {}.
3234
@@ -173,7 +175,7 @@ def get_by_key(self, key: str) -> Optional[Dict[str, Any]]:
173175
if data:
174176
self.expire(key)
175177

176-
return self._process_cache_data(data)
178+
return self._process_cache_data(data) # type: ignore
177179

178180
def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]:
179181
"""Get multiple embeddings by their Redis keys.
@@ -570,7 +572,7 @@ async def aget_by_key(self, key: str) -> Optional[Dict[str, Any]]:
570572
client = await self._get_async_redis_client()
571573

572574
# Get all fields
573-
data = await client.hgetall(key)
575+
data = await client.hgetall(key) # type: ignore
574576

575577
# Refresh TTL if data exists
576578
if data:
@@ -608,7 +610,7 @@ async def amget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]
608610
async with client.pipeline(transaction=False) as pipeline:
609611
# Queue all hgetall operations
610612
for key in keys:
611-
await pipeline.hgetall(key)
613+
pipeline.hgetall(key)
612614
results = await pipeline.execute()
613615

614616
# Process results and refresh TTLs separately

0 commit comments

Comments
 (0)