2727 #==================== Cache class definition =========================#
2828 #=====================================================================#
2929
30- executor = ThreadPoolExecutor (max_workers = 2 )
31-
32- def response_text (cache_resp ):
33- return cache_resp ['data' ]
34-
35- def response_hitquery (cache_resp ):
36- return cache_resp ['hitQuery' ]
3730
3831# noinspection PyMethodMayBeStatic
3932class Cache :
@@ -80,11 +73,16 @@ def close():
8073 modelcache_log .error (e )
8174
8275 def save_query_resp (self , query_resp_dict , ** kwargs ):
83- self .data_manager .save_query_resp (query_resp_dict , ** kwargs )
76+ asyncio .create_task (asyncio .to_thread (
77+ self .data_manager .save_query_resp ,
78+ query_resp_dict , ** kwargs
79+ ))
8480
8581 def save_query_info (self ,result , model , query , delta_time_log ):
86- self .data_manager .save_query_resp (result , model = model , query = json .dumps (query , ensure_ascii = False ),
87- delta_time = delta_time_log )
82+ asyncio .create_task (asyncio .to_thread (
83+ self .data_manager .save_query_resp ,
84+ result , model = model , query = json .dumps (query , ensure_ascii = False ), delta_time = delta_time_log
85+ ))
8886
8987 async def handle_request (self , param_dict : dict ):
9088 # param parsing
@@ -103,7 +101,7 @@ async def handle_request(self, param_dict: dict):
103101 result = {"errorCode" : 102 ,
104102 "errorDesc" : "type exception, should one of ['query', 'insert', 'remove', 'register']" ,
105103 "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
106- self .data_manager . save_query_resp (result , model = model , query = '' , delta_time = 0 )
104+ self .save_query_resp (result , model = model , query = '' , delta_time = 0 )
107105 return result
108106 except Exception as e :
109107 return {"errorCode" : 103 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' ,
@@ -120,14 +118,14 @@ async def handle_request(self, param_dict: dict):
120118 elif request_type == 'insert' :
121119 return await self .handle_insert (chat_info , model )
122120 elif request_type == 'remove' :
123- return self .handle_remove (model , param_dict )
121+ return await self .handle_remove (model , param_dict )
124122 elif request_type == 'register' :
125- return self .handle_register (model )
123+ return await self .handle_register (model )
126124 else :
127125 return {"errorCode" : 400 , "errorDesc" : "bad request" }
128126
129- def handle_register (self , model ):
130- response = adapter .ChatCompletion .create_register (
127+ async def handle_register (self , model ):
128+ response = await adapter .ChatCompletion .create_register (
131129 model = model ,
132130 cache_obj = self
133131 )
@@ -137,10 +135,10 @@ def handle_register(self, model):
137135 result = {"errorCode" : 502 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
138136 return result
139137
140- def handle_remove (self , model , param_dict ):
138+ async def handle_remove (self , model , param_dict ):
141139 remove_type = param_dict .get ("remove_type" )
142140 id_list = param_dict .get ("id_list" , [])
143- response = adapter .ChatCompletion .create_remove (
141+ response = await adapter .ChatCompletion .create_remove (
144142 model = model ,
145143 remove_type = remove_type ,
146144 id_list = id_list ,
@@ -191,12 +189,12 @@ async def handle_query(self, model, query):
191189 result = {"errorCode" : 201 , "errorDesc" : response , "cacheHit" : False , "delta_time" : delta_time ,
192190 "hit_query" : '' , "answer" : '' }
193191 else :
194- answer = response_text ( response )
195- hit_query = response_hitquery ( response )
192+ answer = response [ 'data' ]
193+ hit_query = response [ 'hitQuery' ]
196194 result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : True , "delta_time" : delta_time ,
197195 "hit_query" : hit_query , "answer" : answer }
198196 delta_time_log = round (time .time () - start_time , 2 )
199- executor . submit ( self .save_query_info , result , model , query , delta_time_log )
197+ self .save_query_info ( result , model , query , delta_time_log )
200198 except Exception as e :
201199 result = {"errorCode" : 202 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 ,
202200 "hit_query" : '' , "answer" : '' }
@@ -265,7 +263,9 @@ async def init(
265263 #==================================================#
266264
267265 # switching based on embedding_model
268- if embedding_model == EmbeddingModel .HUGGINGFACE_ALL_MPNET_BASE_V2 :
266+ if (embedding_model == EmbeddingModel .HUGGINGFACE_ALL_MPNET_BASE_V2
267+ or embedding_model == EmbeddingModel .HUGGINGFACE_ALL_MINILM_L6_V2
268+ or embedding_model == EmbeddingModel .HUGGINGFACE_ALL_MINILM_L12_V2 ):
269269 query_pre_embedding_func = query_with_role
270270 insert_pre_embedding_func = query_with_role
271271 post_process_messages_func = first
@@ -287,8 +287,8 @@ async def init(
287287
288288 # add more configurations for other embedding models as needed
289289 else :
290- modelcache_log .error (f"Please add configuration for { embedding_model } in modelcache/__init__ .py." )
291- raise CacheError (f"Please add configuration for { embedding_model } in modelcache/__init__ .py." )
290+ modelcache_log .error (f"Please add configuration for { embedding_model } in modelcache/cache .py." )
291+ raise CacheError (f"Please add configuration for { embedding_model } in modelcache/cache .py." )
292292
293293 # ====================== Data manager ==============================#
294294
@@ -300,7 +300,7 @@ async def init(
300300 config = vector_config ,
301301 metric_type = similarity_metric_type ,
302302 ),
303- eviction = 'ARC' ,
303+ memory_cache_policy = 'ARC' ,
304304 max_size = 10000 ,
305305 normalize = normalize ,
306306 )
0 commit comments