@@ -21,19 +21,23 @@ def __init__(
2121 port : str = "6379" ,
2222 username : str = "" ,
2323 password : str = "" ,
24- dimension : int = 0 ,
24+ mm_dimension : int = 0 ,
25+ i_dimension : int = 0 ,
26+ t_dimension : int = 0 ,
2527 top_k : int = 1 ,
2628 namespace : str = "" ,
2729 ):
28- if dimension <= 0 :
30+ if mm_dimension <= 0 :
2931 raise ValueError (
30- f"invalid `dim` param: { dimension } in the Milvus vector store."
32+ f"invalid `dim` param: { mm_dimension } in the Milvus vector store."
3133 )
3234 self ._client = Redis (
3335 host = host , port = int (port ), username = username , password = password
3436 )
3537 self .top_k = top_k
36- self .dimension = dimension
38+ self .mm_dimension = mm_dimension
39+ self .i_dimension = i_dimension
40+ self .t_dimension = t_dimension
3741 self .namespace = namespace
3842 self .doc_prefix = f"{ self .namespace } doc:"
3943
@@ -47,8 +51,16 @@ def _check_index_exists(self, index_name: str) -> bool:
4751 modelcache_log .info ("Index already exists" )
4852 return True
4953
50- def create_index (self , index_name , index_prefix ):
51- dimension = self .dimension
54+ def create_index (self , index_name , mm_type , index_prefix ):
55+ # dimension = self.dimension
56+ if mm_type == 'IMG_TEXT' :
57+ dimension = self .mm_dimension
58+ elif mm_type == 'IMG' :
59+ dimension = self .i_dimension
60+ elif mm_type == 'TEXT' :
61+ dimension = self .t_dimension
62+ else :
63+ raise ValueError ('dimension type exception' )
5264 print ('dimension: {}' .format (dimension ))
5365 if self ._check_index_exists (index_name ):
5466 modelcache_log .info (
@@ -77,13 +89,17 @@ def create_index(self, index_name, index_prefix):
7789 )
7890 return 'create_success'
7991
80- def mul_add (self , datas : List [VectorData ], model = None ):
81- # pipe = self._client.pipeline()
92+ def mul_add (self , datas : List [VectorData ], model = None , mm_type = None ):
8293 for data in datas :
8394 id : int = data .id
8495 embedding = data .data .astype (np .float32 ).tobytes ()
96+
97+ collection_name = get_collection_iat_name (model , mm_type )
98+ index_prefix = get_collection_iat_prefix (model , mm_type )
99+
85100 id_field_name = "data_id"
86101 embedding_field_name = "data_vector"
102+
87103 obj = {id_field_name : id , embedding_field_name : embedding }
88104 index_prefix = get_index_prefix (model )
89105 self ._client .hset (f"{ index_prefix } { id } " , mapping = obj )
0 commit comments