77import cachetools
88from abc import abstractmethod , ABCMeta
99from typing import List , Any , Optional , Union
10+
11+ from numpy import ndarray
12+
1013from modelcache .manager .scalar_data .base import (
1114 CacheStorage ,
1215 CacheData ,
2124from modelcache .manager .eviction_manager import EvictionManager
2225from modelcache .utils .log import modelcache_log
2326
27+ NORMALIZE = True
2428
2529class DataManager (metaclass = ABCMeta ):
2630 """DataManager manage the cache data, including save and search"""
@@ -158,9 +162,9 @@ def __init__(
158162 self .v = v
159163 self .o = o
160164
161- def save (self , question , answer , embedding_data , ** kwargs ):
165+ def save (self , questions : List [ any ], answers : List [ any ], embedding_datas : List [ any ] , ** kwargs ):
162166 model = kwargs .pop ("model" , None )
163- self .import_data ([ question ], [ answer ], [ embedding_data ] , model )
167+ self .import_data (questions , answers , embedding_datas , model )
164168
165169 def save_query_resp (self , query_resp_dict , ** kwargs ):
166170 save_query_start_time = time .time ()
@@ -197,9 +201,10 @@ def import_data(
197201 raise ParamError ("Make sure that all parameters have the same length" )
198202 cache_datas = []
199203
200- embedding_datas = [
201- normalize (embedding_data ) for embedding_data in embedding_datas
202- ]
204+ if NORMALIZE :
205+ embedding_datas = [
206+ normalize (embedding_data ) for embedding_data in embedding_datas
207+ ]
203208
204209 for i , embedding_data in enumerate (embedding_datas ):
205210 if self .o is not None :
@@ -212,11 +217,9 @@ def import_data(
212217 cache_datas .append ([ans , question , embedding_data , model ])
213218
214219 ids = self .s .batch_insert (cache_datas )
220+ datas_ = [VectorData (id = ids [i ], data = embedding_data .astype ("float32" )) for i , embedding_data in enumerate (embedding_datas )]
215221 self .v .mul_add (
216- [
217- VectorData (id = ids [i ], data = embedding_data )
218- for i , embedding_data in enumerate (embedding_datas )
219- ],
222+ datas_ ,
220223 model
221224
222225 )
@@ -235,7 +238,8 @@ def hit_cache_callback(self, res_data, **kwargs):
235238
236239 def search (self , embedding_data , ** kwargs ):
237240 model = kwargs .pop ("model" , None )
238- embedding_data = normalize (embedding_data )
241+ if NORMALIZE :
242+ embedding_data = normalize (embedding_data )
239243 top_k = kwargs .get ("top_k" , - 1 )
240244 return self .v .search (data = embedding_data , top_k = top_k , model = model )
241245
0 commit comments