@@ -94,7 +94,24 @@ def decorator(func):
9494 @wraps (func )
9595 def wrapper (self , * args , ** kwargs ):
9696 result = func (self , * args , ** kwargs )
97- RedisConnectionFactory .validate_redis (self ._redis_client , self ._lib_name )
97+ RedisConnectionFactory .validate_sync_redis (
98+ self ._redis_client , self ._lib_name
99+ )
100+ return result
101+
102+ return wrapper
103+
104+ return decorator
105+
106+
107+ def setup_async_redis ():
108+ def decorator (func ):
109+ @wraps (func )
110+ async def wrapper (self , * args , ** kwargs ):
111+ result = await func (self , * args , ** kwargs )
112+ await RedisConnectionFactory .validate_async_redis (
113+ self ._redis_client , self ._lib_name
114+ )
98115 return result
99116
100117 return wrapper
@@ -140,41 +157,10 @@ class BaseSearchIndex:
140157 StorageType .JSON : JsonStorage ,
141158 }
142159
143- def __init__ (
144- self ,
145- schema : IndexSchema ,
146- redis_client : Optional [Union [redis .Redis , aredis .Redis ]] = None ,
147- redis_url : Optional [str ] = None ,
148- connection_args : Dict [str , Any ] = {},
149- ** kwargs ,
150- ):
151- """Initialize the RedisVL search index with a schema, Redis client
152- (or URL string with other connection args), connection_args, and other
153- kwargs.
154-
155- Args:
156- schema (IndexSchema): Index schema object.
157- redis_client(Union[redis.Redis, aredis.Redis], optional): An
158- instantiated redis client.
159- redis_url (str, optional): The URL of the Redis server to
160- connect to.
161- connection_args (Dict[str, Any], optional): Redis client connection
162- args.
163- """
164- # final validation on schema object
165- if not isinstance (schema , IndexSchema ):
166- raise ValueError ("Must provide a valid IndexSchema object" )
167-
168- self .schema = schema
169-
170- self ._lib_name : Optional [str ] = kwargs .pop ("lib_name" , None )
160+ schema : IndexSchema
171161
172- # set up redis connection
173- self ._redis_client : Optional [Union [redis .Redis , aredis .Redis ]] = None
174- if redis_client is not None :
175- self .set_client (redis_client )
176- elif redis_url is not None :
177- self .connect (redis_url , ** connection_args )
162+ def __init__ (* args , ** kwargs ):
163+ pass
178164
179165 @property
180166 def _storage (self ) -> BaseStorage :
@@ -237,8 +223,6 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
237223
238224 Args:
239225 schema_dict (Dict[str, Any]): A dictionary containing the schema.
240- connection_args (Dict[str, Any], optional): Redis client connection
241- args.
242226
243227 Returns:
244228 SearchIndex: A RedisVL SearchIndex object.
@@ -262,14 +246,6 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
262246 schema = IndexSchema .from_dict (schema_dict )
263247 return cls (schema = schema , ** kwargs )
264248
265- def connect (self , redis_url : Optional [str ] = None , ** kwargs ):
266- """Connect to Redis at a given URL."""
267- raise NotImplementedError
268-
269- def set_client (self , client : Union [redis .Redis , aredis .Redis ]):
270- """Manually set the Redis client to use with the search index."""
271- raise NotImplementedError
272-
273249 def disconnect (self ):
274250 """Disconnect from the Redis database."""
275251 self ._redis_client = None
@@ -323,6 +299,43 @@ class SearchIndex(BaseSearchIndex):
323299
324300 """
325301
302+ def __init__ (
303+ self ,
304+ schema : IndexSchema ,
305+ redis_client : Optional [redis .Redis ] = None ,
306+ redis_url : Optional [str ] = None ,
307+ connection_args : Dict [str , Any ] = {},
308+ ** kwargs ,
309+ ):
310+ """Initialize the RedisVL search index with a schema, Redis client
311+ (or URL string with other connection args), connection_args, and other
312+ kwargs.
313+
314+ Args:
315+ schema (IndexSchema): Index schema object.
316+ redis_client(Optional[redis.Redis]): An
317+ instantiated redis client.
318+ redis_url (Optional[str]): The URL of the Redis server to
319+ connect to.
320+ connection_args (Dict[str, Any], optional): Redis client connection
321+ args.
322+ """
323+ # final validation on schema object
324+ if not isinstance (schema , IndexSchema ):
325+ raise ValueError ("Must provide a valid IndexSchema object" )
326+
327+ self .schema = schema
328+
329+ self ._lib_name : Optional [str ] = kwargs .pop ("lib_name" , None )
330+
331+ # set up redis connection
332+ self ._redis_client : Optional [redis .Redis ] = None
333+
334+ if redis_client is not None :
335+ self .set_client (redis_client )
336+ elif redis_url is not None :
337+ self .connect (redis_url , ** connection_args )
338+
326339 @classmethod
327340 def from_existing (
328341 cls ,
@@ -342,7 +355,7 @@ def from_existing(
342355 )
343356
344357 # Validate modules
345- installed_modules = RedisConnectionFactory ._get_modules (redis_client )
358+ installed_modules = RedisConnectionFactory .get_modules (redis_client )
346359 validate_modules (installed_modules , [{"name" : "search" , "ver" : 20810 }])
347360
348361 # Fetch index info and convert to schema
@@ -380,15 +393,15 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
380393 return self .set_client (client )
381394
382395 @setup_redis ()
383- def set_client (self , client : redis .Redis , ** kwargs ):
396+ def set_client (self , redis_client : redis .Redis , ** kwargs ):
384397 """Manually set the Redis client to use with the search index.
385398
386399 This method configures the search index to use a specific Redis or
387400 Async Redis client. It is useful for cases where an external,
388401 custom-configured client is preferred instead of creating a new one.
389402
390403 Args:
391- client (redis.Redis): A Redis or Async Redis
404+ redis_client (redis.Redis): A Redis or Async Redis
392405 client instance to be used for the connection.
393406
394407 Raises:
@@ -404,10 +417,10 @@ def set_client(self, client: redis.Redis, **kwargs):
404417 index.set_client(client)
405418
406419 """
407- if not isinstance (client , redis .Redis ):
420+ if not isinstance (redis_client , redis .Redis ):
408421 raise TypeError ("Invalid Redis client instance" )
409422
410- self ._redis_client = client
423+ self ._redis_client = redis_client
411424
412425 return self
413426
@@ -759,7 +772,7 @@ class AsyncSearchIndex(BaseSearchIndex):
759772
760773 # initialize the index object with schema from file
761774 index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
762- index.connect(redis_url="redis://localhost:6379")
775+ await index.connect(redis_url="redis://localhost:6379")
763776
764777 # create the index
765778 await index.create(overwrite=True)
@@ -772,6 +785,34 @@ class AsyncSearchIndex(BaseSearchIndex):
772785
773786 """
774787
788+ def __init__ (
789+ self ,
790+ schema : IndexSchema ,
791+ ** kwargs ,
792+ ):
793+ """Initialize the RedisVL async search index with a schema.
794+
795+ Args:
796+ schema (IndexSchema): Index schema object.
797+ connection_args (Dict[str, Any], optional): Redis client connection
798+ args.
799+ """
800+ # final validation on schema object
801+ if not isinstance (schema , IndexSchema ):
802+ raise ValueError ("Must provide a valid IndexSchema object" )
803+
804+ self .schema = schema
805+
806+ self ._lib_name : Optional [str ] = kwargs .pop ("lib_name" , None )
807+
808+ # set up empty redis connection
809+ self ._redis_client : Optional [aredis .Redis ] = None
810+
811+ if "redis_client" in kwargs or "redis_url" in kwargs :
812+ logger .warning (
813+ "Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex"
814+ )
815+
775816 @classmethod
776817 async def from_existing (
777818 cls ,
@@ -791,18 +832,18 @@ async def from_existing(
791832 )
792833
793834 # Validate modules
794- installed_modules = await RedisConnectionFactory ._get_modules_async (
795- redis_client
796- )
835+ installed_modules = await RedisConnectionFactory .get_modules_async (redis_client )
797836 validate_modules (installed_modules , [{"name" : "search" , "ver" : 20810 }])
798837
799838 # Fetch index info and convert to schema
800839 index_info = await cls ._info (name , redis_client )
801840 schema_dict = convert_index_info_to_schema (index_info )
802841 schema = IndexSchema .from_dict (schema_dict )
803- return cls (schema , redis_client , ** kwargs )
842+ index = cls (schema , ** kwargs )
843+ await index .set_client (redis_client )
844+ return index
804845
805- def connect (self , redis_url : Optional [str ] = None , ** kwargs ):
846+ async def connect (self , redis_url : Optional [str ] = None , ** kwargs ):
806847 """Connect to a Redis instance using the provided `redis_url`, falling
807848 back to the `REDIS_URL` environment variable (if available).
808849
@@ -828,18 +869,18 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
828869 client = RedisConnectionFactory .connect (
829870 redis_url = redis_url , use_async = True , ** kwargs
830871 )
831- return self .set_client (client )
872+ return await self .set_client (client )
832873
833- @setup_redis ()
834- def set_client (self , client : aredis .Redis ):
874+ @setup_async_redis ()
875+ async def set_client (self , redis_client : aredis .Redis ):
835876 """Manually set the Redis client to use with the search index.
836877
837878 This method configures the search index to use a specific
838879 Async Redis client. It is useful for cases where an external,
839880 custom-configured client is preferred instead of creating a new one.
840881
841882 Args:
842- client (aredis.Redis): An Async Redis
883+ redis_client (aredis.Redis): An Async Redis
843884 client instance to be used for the connection.
844885
845886 Raises:
@@ -853,13 +894,13 @@ def set_client(self, client: aredis.Redis):
853894 # async Redis client and index
854895 client = aredis.Redis.from_url("redis://localhost:6379")
855896 index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
856- index.set_client(client)
897+ await index.set_client(client)
857898
858899 """
859- if not isinstance (client , aredis .Redis ):
900+ if not isinstance (redis_client , aredis .Redis ):
860901 raise TypeError ("Invalid Redis client instance" )
861902
862- self ._redis_client = client
903+ self ._redis_client = redis_client
863904
864905 return self
865906
@@ -889,6 +930,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
889930 await index.create(overwrite=True, drop=True)
890931 """
891932 redis_fields = self .schema .redis_fields
933+
892934 if not redis_fields :
893935 raise ValueError ("No fields defined for index" )
894936 if not isinstance (overwrite , bool ):
0 commit comments