1- from typing import Any , Dict , List , Optional
1+ from typing import Any , Dict , List , Optional , Union
22
33from redis import Redis
44
55from redisvl .extensions .llmcache .base import BaseLLMCache
66from redisvl .index import SearchIndex
77from redisvl .query import RangeQuery
8+ from redisvl .query .filter import FilterExpression , Tag
89from redisvl .redis .utils import array_to_buffer
9- from redisvl .schema .schema import IndexSchema
10+ from redisvl .schema import IndexSchema
11+ from redisvl .utils .utils import current_timestamp , deserialize , serialize
1012from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
1113
1214
class SemanticCacheIndexSchema(IndexSchema):
    """Default Redis index schema used by the semantic cache."""

    @classmethod
    def from_params(cls, name: str, vector_dims: int):
        """Construct a cache index schema for the given index name and
        embedding dimensionality.

        Args:
            name (str): Index name, also used as the key prefix.
            vector_dims (int): Dimensionality of the prompt vector field.
        """
        vector_field = {
            "name": "prompt_vector",
            "type": "vector",
            "attrs": {
                "dims": vector_dims,
                "datatype": "float32",
                "distance_metric": "cosine",
                "algorithm": "flat",
            },
        }
        schema_fields = [
            {"name": "prompt", "type": "text"},
            {"name": "response", "type": "text"},
            {"name": "inserted_at", "type": "numeric"},
            {"name": "updated_at", "type": "numeric"},
            {"name": "label", "type": "tag"},
            vector_field,
        ]
        return cls(
            index={"name": name, "prefix": name},  # type: ignore
            fields=schema_fields,  # type: ignore
        )
40+
41+
1342class SemanticCache (BaseLLMCache ):
1443 """Semantic Cache for Large Language Models."""
1544
1645 entry_id_field_name : str = "_id"
1746 prompt_field_name : str = "prompt"
1847 vector_field_name : str = "prompt_vector"
48+ inserted_at_field_name : str = "inserted_at"
49+ updated_at_field_name : str = "updated_at"
50+ tag_field_name : str = "label"
1951 response_field_name : str = "response"
2052 metadata_field_name : str = "metadata"
2153
@@ -69,27 +101,7 @@ def __init__(
69101 model = "sentence-transformers/all-mpnet-base-v2"
70102 )
71103
72- # build cache index schema
73- schema = IndexSchema .from_dict ({"index" : {"name" : name , "prefix" : prefix }})
74- # add fields
75- schema .add_fields (
76- [
77- {"name" : self .prompt_field_name , "type" : "text" },
78- {"name" : self .response_field_name , "type" : "text" },
79- {
80- "name" : self .vector_field_name ,
81- "type" : "vector" ,
82- "attrs" : {
83- "dims" : vectorizer .dims ,
84- "datatype" : "float32" ,
85- "distance_metric" : "cosine" ,
86- "algorithm" : "flat" ,
87- },
88- },
89- ]
90- )
91-
92- # build search index
104+ schema = SemanticCacheIndexSchema .from_params (name , vectorizer .dims )
93105 self ._index = SearchIndex (schema = schema )
94106
95107 # handle redis connection
@@ -103,12 +115,12 @@ def __init__(
103115 self .entry_id_field_name ,
104116 self .prompt_field_name ,
105117 self .response_field_name ,
118+ self .tag_field_name ,
106119 self .vector_field_name ,
107120 self .metadata_field_name ,
108121 ]
109122 self .set_vectorizer (vectorizer )
110123 self .set_threshold (distance_threshold )
111-
112124 self ._index .create (overwrite = False )
113125
114126 @property
@@ -182,6 +194,14 @@ def delete(self) -> None:
182194 index."""
183195 self ._index .delete (drop = True )
184196
    def drop(self, document_ids: Union[str, List[str]]) -> None:
        """Remove a specific entry or entries from the cache by its ID.

        Args:
            document_ids (Union[str, List[str]]): The document ID or IDs to
                remove from the cache.
        """
        # Delegates key deletion to the underlying search index.
        self._index.drop_keys(document_ids)
185205 def _refresh_ttl (self , key : str ) -> None :
186206 """Refresh the time-to-live for the specified key."""
187207 if self ._ttl :
@@ -195,7 +215,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
195215 return self ._vectorizer .embed (prompt )
196216
197217 def _search_cache (
198- self , vector : List [float ], num_results : int , return_fields : Optional [List [str ]]
218+ self ,
219+ vector : List [float ],
220+ num_results : int ,
221+ return_fields : Optional [List [str ]],
222+ tag_filter : Optional [FilterExpression ],
199223 ) -> List [Dict [str , Any ]]:
200224 """Searches the semantic cache for similar prompt vectors and returns
201225 the specified return fields for each cache hit."""
@@ -217,6 +241,8 @@ def _search_cache(
217241 num_results = num_results ,
218242 return_score = True ,
219243 )
244+ if tag_filter :
245+ query .set_filter (tag_filter ) # type: ignore
220246
221247 # Gather and return the cache hits
222248 cache_hits : List [Dict [str , Any ]] = self ._index .query (query )
@@ -226,7 +252,7 @@ def _search_cache(
226252 self ._refresh_ttl (key )
227253 # Check for metadata and deserialize
228254 if self .metadata_field_name in hit :
229- hit [self .metadata_field_name ] = self . deserialize (
255+ hit [self .metadata_field_name ] = deserialize (
230256 hit [self .metadata_field_name ]
231257 )
232258 return cache_hits
@@ -248,6 +274,7 @@ def check(
248274 vector : Optional [List [float ]] = None ,
249275 num_results : int = 1 ,
250276 return_fields : Optional [List [str ]] = None ,
277+ tag_filter : Optional [FilterExpression ] = None ,
251278 ) -> List [Dict [str , Any ]]:
252279 """Checks the semantic cache for results similar to the specified prompt
253280 or vector.
@@ -267,6 +294,8 @@ def check(
267294 return_fields (Optional[List[str]], optional): The fields to include
268295 in each returned result. If None, defaults to all available
269296 fields in the cached entry.
297+ tag_filter (Optional[FilterExpression]) : the tag filter to filter
298+ results by. Default is None and full cache is searched.
270299
271300 Returns:
272301 List[Dict[str, Any]]: A list of dicts containing the requested
@@ -291,7 +320,7 @@ def check(
291320 self ._check_vector_dims (vector )
292321
293322 # Check for cache hits by searching the cache
294- cache_hits = self ._search_cache (vector , num_results , return_fields )
323+ cache_hits = self ._search_cache (vector , num_results , return_fields , tag_filter )
295324 return cache_hits
296325
297326 def store (
@@ -300,6 +329,7 @@ def store(
300329 response : str ,
301330 vector : Optional [List [float ]] = None ,
302331 metadata : Optional [dict ] = None ,
332+ tag : Optional [str ] = None ,
303333 ) -> str :
304334 """Stores the specified key-value pair in the cache along with metadata.
305335
@@ -311,6 +341,8 @@ def store(
311341 demand.
312342 metadata (Optional[dict], optional): The optional metadata to cache
313343 alongside the prompt and response. Defaults to None.
344+ tag (Optional[str]): The optional tag to assign to the cache entry.
345+ Defaults to None.
314346
315347 Returns:
316348 str: The Redis key for the entries added to the semantic cache.
@@ -333,19 +365,67 @@ def store(
333365 self ._check_vector_dims (vector )
334366
335367 # Construct semantic cache payload
368+ now = current_timestamp ()
336369 id_field = self .entry_id_field_name
337370 payload = {
338371 id_field : self .hash_input (prompt ),
339372 self .prompt_field_name : prompt ,
340373 self .response_field_name : response ,
341374 self .vector_field_name : array_to_buffer (vector ),
375+ self .inserted_at_field_name : now ,
376+ self .updated_at_field_name : now ,
342377 }
343378 if metadata is not None :
344379 if not isinstance (metadata , dict ):
345380 raise TypeError ("If specified, cached metadata must be a dictionary." )
346381 # Serialize the metadata dict and add to cache payload
347- payload [self .metadata_field_name ] = self .serialize (metadata )
382+ payload [self .metadata_field_name ] = serialize (metadata )
383+ if tag is not None :
384+ payload [self .tag_field_name ] = tag
348385
349386 # Load LLMCache entry with TTL
350387 keys = self ._index .load (data = [payload ], ttl = self ._ttl , id_field = id_field )
351388 return keys [0 ]
389+
390+ def update (self , key : str , ** kwargs ) -> None :
391+ """Update specific fields within an existing cache entry. If no fields
392+ are passed, then only the document TTL is refreshed.
393+
394+ Args:
395+ key (str): the key of the document to update.
396+ kwargs:
397+
398+ Raises:
399+ ValueError if an incorrect mapping is provided as a kwarg.
400+ TypeError if metadata is provided and not of type dict.
401+
402+ .. code-block:: python
403+ key = cache.store('this is a prompt', 'this is a response')
404+ cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
405+ )
406+ """
407+ if not kwargs :
408+ self ._refresh_ttl (key )
409+ return
410+
411+ for _key , val in kwargs .items ():
412+ if _key not in {
413+ self .prompt_field_name ,
414+ self .vector_field_name ,
415+ self .response_field_name ,
416+ self .tag_field_name ,
417+ self .metadata_field_name ,
418+ }:
419+ raise ValueError (f" { key } is not a valid field within document" )
420+
421+ # Check for metadata and deserialize
422+ if _key == self .metadata_field_name :
423+ if isinstance (val , dict ):
424+ kwargs [_key ] = serialize (val )
425+ else :
426+ raise TypeError (
427+ "If specified, cached metadata must be a dictionary."
428+ )
429+ kwargs .update ({self .updated_at_field_name : current_timestamp ()})
430+ self ._index .client .hset (key , mapping = kwargs ) # type: ignore
431+ self ._refresh_ttl (key )
0 commit comments