2222from redisvl .index import AsyncSearchIndex , SearchIndex
2323from redisvl .query import RangeQuery
2424from redisvl .query .filter import FilterExpression
25- from redisvl .utils .utils import current_timestamp , serialize , validate_vector_dims
25+ from redisvl .utils .utils import (
26+ current_timestamp ,
27+ deprecated_argument ,
28+ serialize ,
29+ validate_vector_dims ,
30+ )
2631from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
2732
2833
@@ -32,6 +37,7 @@ class SemanticCache(BaseLLMCache):
3237 _index : SearchIndex
3338 _aindex : Optional [AsyncSearchIndex ] = None
3439
40+ @deprecated_argument ("dtype" , "vectorizer" )
3541 def __init__ (
3642 self ,
3743 name : str = "llmcache" ,
@@ -86,12 +92,26 @@ def __init__(
8692 else :
8793 prefix = name
8894
89- # Set vectorizer default
90- if vectorizer is None :
95+ dtype = kwargs .get ("dtype" )
96+
97+ # Validate a provided vectorizer or set the default
98+ if vectorizer :
99+ if not isinstance (vectorizer , BaseVectorizer ):
100+ raise TypeError ("Must provide a valid redisvl.vectorizer class." )
101+ if dtype and vectorizer .dtype != dtype :
102+ raise ValueError (
103+ f"Provided dtype { dtype } does not match vectorizer dtype { vectorizer .dtype } "
104+ )
105+ else :
106+ vectorizer_kwargs = {"dtype" : dtype } if dtype else {}
107+
91108 vectorizer = HFTextVectorizer (
92- model = "sentence-transformers/all-mpnet-base-v2"
109+ model = "sentence-transformers/all-mpnet-base-v2" ,
110+ ** vectorizer_kwargs ,
93111 )
94112
113+ self ._vectorizer = vectorizer
114+
95115 # Process fields and other settings
96116 self .set_threshold (distance_threshold )
97117 self .return_fields = [
@@ -104,9 +124,8 @@ def __init__(
104124 ]
105125
106126 # Create semantic cache schema and index
107- dtype = kwargs .get ("dtype" , "float32" )
108127 schema = SemanticCacheIndexSchema .from_params (
109- name , prefix , vectorizer .dims , dtype
128+ name , prefix , vectorizer .dims , vectorizer . dtype
110129 )
111130 schema = self ._modify_schema (schema , filterable_fields )
112131 self ._index = SearchIndex (schema = schema )
@@ -128,20 +147,9 @@ def __init__(
128147 "If you wish to overwrite the index schema, set overwrite=True during initialization."
129148 )
130149
131- # Create the search index
150+ # Create the search index in Redis
132151 self ._index .create (overwrite = overwrite , drop = False )
133152
134- # Initialize and validate vectorizer
135- if not isinstance (vectorizer , BaseVectorizer ):
136- raise TypeError ("Must provide a valid redisvl.vectorizer class." )
137-
138- validate_vector_dims (
139- vectorizer .dims ,
140- self ._index .schema .fields [CACHE_VECTOR_FIELD_NAME ].attrs .dims , # type: ignore
141- )
142- self ._vectorizer = vectorizer
143- self ._dtype = self .index .schema .fields [CACHE_VECTOR_FIELD_NAME ].attrs .datatype # type: ignore[union-attr]
144-
145153 def _modify_schema (
146154 self ,
147155 schema : SemanticCacheIndexSchema ,
@@ -290,7 +298,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
290298 if not isinstance (prompt , str ):
291299 raise TypeError ("Prompt must be a string." )
292300
293- return self ._vectorizer .embed (prompt , dtype = self . _dtype )
301+ return self ._vectorizer .embed (prompt )
294302
295303 async def _avectorize_prompt (self , prompt : Optional [str ]) -> List [float ]:
296304 """Converts a text prompt to its vector representation using the
@@ -372,7 +380,7 @@ def check(
372380 num_results = num_results ,
373381 return_score = True ,
374382 filter_expression = filter_expression ,
375- dtype = self ._dtype ,
383+ dtype = self ._vectorizer . dtype ,
376384 )
377385
378386 # Search the cache!
@@ -543,7 +551,7 @@ def store(
543551 # Load cache entry with TTL
544552 ttl = ttl or self ._ttl
545553 keys = self ._index .load (
546- data = [cache_entry .to_dict (self ._dtype )],
554+ data = [cache_entry .to_dict (self ._vectorizer . dtype )],
547555 ttl = ttl ,
548556 id_field = ENTRY_ID_FIELD_NAME ,
549557 )
@@ -607,7 +615,7 @@ async def astore(
607615 # Load cache entry with TTL
608616 ttl = ttl or self ._ttl
609617 keys = await aindex .load (
610- data = [cache_entry .to_dict (self ._dtype )],
618+ data = [cache_entry .to_dict (self ._vectorizer . dtype )],
611619 ttl = ttl ,
612620 id_field = ENTRY_ID_FIELD_NAME ,
613621 )
0 commit comments