11# -*- coding: utf-8 -*-
22from typing import List
3- < << << << HEAD
4-
5- import numpy as np
6- from modelcache .manager .vector_data .base import VectorBase , VectorData
7- from modelcache .utils import import_redis
8- from redis .commands .search .query import Query
9- from redis .commands .search .indexDefinition import IndexDefinition , IndexType
10- from modelcache .utils .log import modelcache_log
11-
12- import_redis ()
13- #
14- # from redis.commands.search.indexDefinition import IndexDefinition, IndexType
15- # from redis.commands.search.query import Query
16- # from redis.commands.search.field import TagField, VectorField
17- # from redis.client import Redis
18- == == == =
193import numpy as np
204from redis .commands .search .indexDefinition import IndexDefinition , IndexType
215from redis .commands .search .query import Query
2812from modelcache .utils .index_util import get_index_name
2913from modelcache .utils .index_util import get_index_prefix
3014import_redis ()
31- > >> >> >> main
3215
3316
3417class RedisVectorStore (VectorBase ):
@@ -39,103 +22,25 @@ def __init__(
3922 username : str = "" ,
4023 password : str = "" ,
4124 dimension : int = 0 ,
42- << << << < HEAD
43- collection_name : str = "gptcache" ,
44- top_k : int = 1 ,
45- namespace : str = "" ,
46- ):
47- == == == =
4825 top_k : int = 1 ,
4926 namespace : str = "" ,
5027 ):
5128 if dimension <= 0 :
5229 raise ValueError (
5330 f"invalid `dim` param: { dimension } in the Redis vector store."
5431 )
55- >> >> >> > main
5632 self ._client = Redis (
5733 host = host , port = int (port ), username = username , password = password
5834 )
5935 self .top_k = top_k
6036 self .dimension = dimension
61- < << << << HEAD
62- self .collection_name = collection_name
63- self .namespace = namespace
64- self .doc_prefix = f"{ self .namespace } doc:" # Prefix with the specified namespace
65- self ._create_collection (collection_name )
66- == == == =
6737 self .namespace = namespace
6838 self .doc_prefix = f"{ self .namespace } doc:"
69- >> >> >> > main
7039
7140 def _check_index_exists (self , index_name : str ) -> bool :
7241 """Check if Redis index exists."""
7342 try :
7443 self ._client .ft (index_name ).info ()
75- << << << < HEAD
76- except : # pylint: disable=W0702
77- gptcache_log .info ("Index does not exist" )
78- return False
79- gptcache_log .info ("Index already exists" )
80- return True
81-
82- def _create_collection (self , collection_name ):
83- if self ._check_index_exists (collection_name ):
84- gptcache_log .info (
85- "The %s already exists, and it will be used directly" , collection_name
86- )
87- else :
88- schema = (
89- TagField ("tag" ), # Tag Field Name
90- VectorField (
91- "vector" , # Vector Field Name
92- "FLAT" ,
93- { # Vector Index Type: FLAT or HNSW
94- "TYPE" : "FLOAT32" , # FLOAT32 or FLOAT64
95- "DIM" : self .dimension , # Number of Vector Dimensions
96- "DISTANCE_METRIC" : "COSINE" , # Vector Search Distance Metric
97- },
98- ),
99- )
100- definition = IndexDefinition (
101- prefix = [self .doc_prefix ], index_type = IndexType .HASH
102- )
103-
104- # create Index
105- self ._client .ft (collection_name ).create_index (
106- fields = schema , definition = definition
107- )
108-
109- def mul_add (self , datas : List [VectorData ]):
110- pipe = self ._client .pipeline ()
111-
112- for data in datas :
113- key : int = data .id
114- obj = {
115- "vector" : data .data .astype (np .float32 ).tobytes (),
116- }
117- pipe .hset (f"{ self .doc_prefix } { key } " , mapping = obj )
118-
119- pipe .execute ()
120-
121- def search (self , data : np .ndarray , top_k : int = - 1 ):
122- query = (
123- Query (
124- f"*=>[KNN { top_k if top_k > 0 else self .top_k } @vector $vec as score]"
125- )
126- .sort_by ("score" )
127- .return_fields ("id" , "score" )
128- .paging (0 , top_k if top_k > 0 else self .top_k )
129- .dialect (2 )
130- )
131- query_params = {"vec" : data .astype (np .float32 ).tobytes ()}
132- results = (
133- self ._client .ft (self .collection_name )
134- .search (query , query_params = query_params )
135- .docs
136- )
137- return [(float (result .score ), int (result .id [len (self .doc_prefix ):])) for result in results ]
138- == == == =
13944 except :
14045 modelcache_log .info ("Index does not exist" )
14146 return False
@@ -201,13 +106,10 @@ def search(self, data: np.ndarray, top_k: int = -1, model=None):
201106 .docs
202107 )
203108 return [(float (result .distance ), int (getattr (result , id_field_name ))) for result in results ]
204- > >> >> >> main
205109
    def rebuild (self , ids = None ) -> bool :
        """No-op placeholder for the VectorBase ``rebuild`` hook.

        NOTE(review): intentionally does nothing; per-model index
        rebuilding appears to be handled by ``rebuild_col`` instead.
        Despite the ``-> bool`` annotation this returns ``None`` --
        confirm that callers ignore the return value.

        :param ids: unused; accepted only to satisfy the base interface.
        """
        pass
208112
209- < << << << HEAD
210- == == == =
211113 def rebuild_col (self , model ):
212114 index_name_model = get_index_name (model )
213115 if self ._check_index_exists (index_name_model ):
@@ -222,14 +124,10 @@ def rebuild_col(self, model):
222124 raise ValueError (str (e ))
223125 # return 'rebuild success'
224126
225- >> >> >> > main
226127 def delete (self , ids ) -> None :
227128 pipe = self ._client .pipeline ()
228129 for data_id in ids :
229130 pipe .delete (f"{ self .doc_prefix } { data_id } " )
230- < << << << HEAD
231- pipe .execute ()
232- == == == =
233131 pipe .execute ()
234132
235133 def create (self , model = None ):
@@ -239,4 +137,3 @@ def create(self, model=None):
239137
    def get_index_by_name (self , index_name ):
        """Unimplemented index lookup; always returns ``None``.

        NOTE(review): stub -- no Redis call is made here. Confirm
        whether callers expect index metadata (e.g. ``ft(name).info()``)
        before relying on this method.

        :param index_name: name of the Redis search index (unused).
        """
        pass
242- > >> >> >> main
0 commit comments