44specific cache types such as LLM caches and embedding caches.
55"""
66
7- from typing import Any , Dict , Optional
7+ from collections .abc import Mapping
8+ from typing import Any , Dict , Optional , Union
89
9- from redis import Redis
10- from redis .asyncio import Redis as AsyncRedis
10+ from redis import Redis # For backwards compatibility in type checking
11+ from redis .cluster import RedisCluster
1112
1213from redisvl .redis .connection import RedisConnectionFactory
14+ from redisvl .types import AsyncRedisClient , SyncRedisClient , SyncRedisCluster
1315
1416
1517class BaseCache :
@@ -19,14 +21,15 @@ class BaseCache:
1921 including TTL management, connection handling, and basic cache operations.
2022 """
2123
22- _redis_client : Optional [Redis ]
23- _async_redis_client : Optional [AsyncRedis ]
24+ _redis_client : Optional [SyncRedisClient ]
25+ _async_redis_client : Optional [AsyncRedisClient ]
2426
2527 def __init__ (
2628 self ,
2729 name : str ,
2830 ttl : Optional [int ] = None ,
29- redis_client : Optional [Redis ] = None ,
31+ redis_client : Optional [SyncRedisClient ] = None ,
32+ async_redis_client : Optional [AsyncRedisClient ] = None ,
3033 redis_url : str = "redis://localhost:6379" ,
3134 connection_kwargs : Dict [str , Any ] = {},
3235 ):
@@ -36,7 +39,7 @@ def __init__(
3639 name (str): The name of the cache.
3740 ttl (Optional[int], optional): The time-to-live for records cached
3841 in Redis. Defaults to None.
39- redis_client (Optional[Redis ], optional): A redis client connection instance.
42+ redis_client (Optional[SyncRedisClient ], optional): A redis client connection instance.
4043 Defaults to None.
4144 redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
4245 connection_kwargs (Dict[str, Any]): The connection arguments
@@ -53,14 +56,13 @@ def __init__(
5356 }
5457
5558 # Initialize Redis clients
56- self ._async_redis_client = None
59+ self ._async_redis_client = async_redis_client
60+ self ._redis_client = redis_client
5761
58- if redis_client :
62+ if redis_client or async_redis_client :
5963 self ._owns_redis_client = False
60- self ._redis_client = redis_client
6164 else :
6265 self ._owns_redis_client = True
63- self ._redis_client = None # type: ignore
6466
6567 def _get_prefix (self ) -> str :
6668 """Get the key prefix for Redis keys.
@@ -103,11 +105,11 @@ def set_ttl(self, ttl: Optional[int] = None) -> None:
103105 else :
104106 self ._ttl = None
105107
106- def _get_redis_client (self ) -> Redis :
108+ def _get_redis_client (self ) -> SyncRedisClient :
107109 """Get or create a Redis client.
108110
109111 Returns:
110- Redis : A Redis client instance.
112+ SyncRedisClient : A Redis client instance.
111113 """
112114 if self ._redis_client is None :
113115 # Create new Redis client
@@ -116,22 +118,29 @@ def _get_redis_client(self) -> Redis:
116118 self ._redis_client = Redis .from_url (url , ** kwargs ) # type: ignore
117119 return self ._redis_client
118120
119- async def _get_async_redis_client (self ) -> AsyncRedis :
121+ async def _get_async_redis_client (self ) -> AsyncRedisClient :
120122 """Get or create an async Redis client.
121123
122124 Returns:
123- AsyncRedis : An async Redis client instance.
125+ AsyncRedisClient : An async Redis client instance.
124126 """
125127 if not hasattr (self , "_async_redis_client" ) or self ._async_redis_client is None :
126128 client = self .redis_kwargs .get ("redis_client" )
127- if isinstance (client , Redis ):
129+
130+ if client and isinstance (client , (Redis , RedisCluster )):
128131 self ._async_redis_client = RedisConnectionFactory .sync_to_async_redis (
129132 client
130133 )
131134 else :
132- url = self .redis_kwargs ["redis_url" ]
133- kwargs = self .redis_kwargs ["connection_kwargs" ]
134- self ._async_redis_client = RedisConnectionFactory .get_async_redis_connection (url , ** kwargs ) # type: ignore
135+ url = str (self .redis_kwargs ["redis_url" ])
136+ kwargs = self .redis_kwargs .get ("connection_kwargs" , {})
137+ if not isinstance (kwargs , Mapping ):
138+ raise ValueError (
139+ f"connection_kwargs must be a mapping, got { type (kwargs )} "
140+ )
141+ self ._async_redis_client = (
142+ RedisConnectionFactory .get_async_redis_connection (url , ** kwargs )
143+ )
135144 return self ._async_redis_client
136145
137146 def expire (self , key : str , ttl : Optional [int ] = None ) -> None :
@@ -183,7 +192,14 @@ def clear(self) -> None:
183192 client .delete (* keys )
184193 if cursor_int == 0 : # Redis returns 0 when scan is complete
185194 break
186- cursor = cursor_int # Update cursor for next iteration
195+ # Cluster returns a dict of cursor values. We need to stop if these all
196+ # come back as 0.
197+ elif isinstance (cursor_int , Mapping ):
198+ cursor_values = list (cursor_int .values ())
199+ if all (v == 0 for v in cursor_values ):
200+ break
201+ else :
202+ cursor = cursor_int # Update cursor for next iteration
187203
188204 async def aclear (self ) -> None :
189205 """Async clear the cache of all keys."""
@@ -193,7 +209,9 @@ async def aclear(self) -> None:
193209 # Scan for all keys with our prefix
194210 cursor = 0 # Start with cursor 0
195211 while True :
196- cursor_int , keys = await client .scan (cursor = cursor , match = f"{ prefix } *" , count = 100 ) # type: ignore
212+ cursor_int , keys = await client .scan (
213+ cursor = cursor , match = f"{ prefix } *" , count = 100
214+ ) # type: ignore
197215 if keys :
198216 await client .delete (* keys )
199217 if cursor_int == 0 : # Redis returns 0 when scan is complete
@@ -207,12 +225,10 @@ def disconnect(self) -> None:
207225
208226 if self ._redis_client :
209227 self ._redis_client .close ()
210- self ._redis_client = None # type: ignore
211-
212- if hasattr (self , "_async_redis_client" ) and self ._async_redis_client :
213- # Use synchronous close for async client in synchronous context
214- self ._async_redis_client .close () # type: ignore
215- self ._async_redis_client = None # type: ignore
228+ self ._redis_client = None
229+ # Async clients don't have a sync close method, so we just
230+ # zero them out to allow garbage collection.
231+ self ._async_redis_client = None
216232
217233 async def adisconnect (self ) -> None :
218234 """Async disconnect from Redis."""
@@ -221,9 +237,9 @@ async def adisconnect(self) -> None:
221237
222238 if self ._redis_client :
223239 self ._redis_client .close ()
224- self ._redis_client = None # type: ignore
240+ self ._redis_client = None
225241
226242 if hasattr (self , "_async_redis_client" ) and self ._async_redis_client :
227243 # Use proper async close method
228- await self ._async_redis_client .aclose () # type: ignore
229- self ._async_redis_client = None # type: ignore
244+ await self ._async_redis_client .aclose ()
245+ self ._async_redis_client = None
0 commit comments