1+ import asyncio
12from typing import Any , Dict , List , Optional
23
34from redis import Redis
@@ -341,8 +342,10 @@ def check(
341342 prompt="What is the captial city of France?"
342343 )
343344 """
344- if not ( prompt or vector ):
345+ if not any ([ prompt , vector ] ):
345346 raise ValueError ("Either prompt or vector must be specified." )
347+ if return_fields and not isinstance (return_fields , list ):
348+ raise TypeError ("Return fields must be a list of values." )
346349
347350 # overrides
348351 distance_threshold = distance_threshold or self ._distance_threshold
@@ -359,25 +362,14 @@ def check(
359362 filter_expression = filter_expression ,
360363 )
361364
362- cache_hits : List [Dict [Any , str ]] = []
363-
364365 # Search the cache!
365366 cache_search_results = self ._index .query (query )
366-
367- for cache_search_result in cache_search_results :
368- redis_key = cache_search_result .pop ("id" )
369- self ._refresh_ttl (redis_key )
370-
371- # Create and process cache hit
372- cache_hit = CacheHit (** cache_search_result )
373- cache_hit_dict = cache_hit .to_dict ()
374- # Filter down to only selected return fields if needed
375- if isinstance (return_fields , list ) and len (return_fields ) > 0 :
376- cache_hit_dict = {
377- k : v for k , v in cache_hit_dict .items () if k in return_fields
378- }
379- cache_hit_dict [self .redis_key_field_name ] = redis_key
380- cache_hits .append (cache_hit_dict )
367+ redis_keys , cache_hits = self ._process_cache_results (
368+ cache_search_results , return_fields # type: ignore
369+ )
370+ # Extend TTL on keys
371+ for key in redis_keys :
372+ self ._refresh_ttl (key )
381373
382374 return cache_hits
383375
@@ -431,19 +423,16 @@ async def acheck(
431423 """
432424 aindex = await self ._get_async_index ()
433425
434- if not ( prompt or vector ):
426+ if not any ([ prompt , vector ] ):
435427 raise ValueError ("Either prompt or vector must be specified." )
428+ if return_fields and not isinstance (return_fields , list ):
429+ raise TypeError ("Return fields must be a list of values." )
436430
437431 # overrides
438432 distance_threshold = distance_threshold or self ._distance_threshold
439- return_fields = return_fields or self .return_fields
440433 vector = vector or await self ._avectorize_prompt (prompt )
441-
442434 self ._check_vector_dims (vector )
443435
444- if not isinstance (return_fields , list ):
445- raise TypeError ("return_fields must be a list of field names" )
446-
447436 query = RangeQuery (
448437 vector = vector ,
449438 vector_field_name = self .vector_field_name ,
@@ -454,24 +443,36 @@ async def acheck(
454443 filter_expression = filter_expression ,
455444 )
456445
457- cache_hits : List [Dict [Any , str ]] = []
458-
459446 # Search the cache!
460447 cache_search_results = await aindex .query (query )
448+ redis_keys , cache_hits = self ._process_cache_results (
449+ cache_search_results , return_fields # type: ignore
450+ )
451+ # Extend TTL on keys
452+ asyncio .gather (* [self ._async_refresh_ttl (key ) for key in redis_keys ])
461453
462- for cache_search_result in cache_search_results :
463- key = cache_search_result ["id" ]
464- await self ._async_refresh_ttl (key )
454+ return cache_hits
465455
466- # Create cache hit
456+ def _process_cache_results (
457+ self , cache_search_results : List [Dict [str , Any ]], return_fields : List [str ]
458+ ):
459+ redis_keys : List [str ] = []
460+ cache_hits : List [Dict [Any , str ]] = []
461+ for cache_search_result in cache_search_results :
462+ # Pop the redis key from the result
463+ redis_key = cache_search_result .pop ("id" )
464+ redis_keys .append (redis_key )
465+ # Create and process cache hit
467466 cache_hit = CacheHit (** cache_search_result )
468- cache_hit_dict = {
469- k : v for k , v in cache_hit .to_dict ().items () if k in return_fields
470- }
471- cache_hit_dict ["key" ] = key
467+ cache_hit_dict = cache_hit .to_dict ()
468+ # Filter down to only selected return fields if needed
469+ if isinstance (return_fields , list ) and len (return_fields ) > 0 :
470+ cache_hit_dict = {
471+ k : v for k , v in cache_hit_dict .items () if k in return_fields
472+ }
473+ cache_hit_dict [self .redis_key_field_name ] = redis_key
472474 cache_hits .append (cache_hit_dict )
473-
474- return cache_hits
475+ return redis_keys , cache_hits
475476
476477 def store (
477478 self ,
0 commit comments