99from pydantic import BaseModel
1010from concurrent .futures import ThreadPoolExecutor
1111from starlette .responses import PlainTextResponse
12+ import functools
1213
1314from modelcache import cache
1415from modelcache .adapter import adapter
@@ -63,9 +64,10 @@ class RequestData(BaseModel):
6364executor = ThreadPoolExecutor (max_workers = 6 )
6465
6566# 异步保存查询信息
66- async def save_query_info (result , model , query , delta_time_log ):
67+ async def save_query_info_fastapi (result , model , query , delta_time_log ):
6768 loop = asyncio .get_running_loop ()
68- await loop .run_in_executor (executor , cache .data_manager .save_query_resp , result , model , json .dumps (query , ensure_ascii = False ), delta_time_log )
69+ func = functools .partial (cache .data_manager .save_query_resp , result , model = model , query = json .dumps (query , ensure_ascii = False ), delta_time = delta_time_log )
70+ await loop .run_in_executor (None , func )
6971
7072
7173
@@ -74,20 +76,53 @@ async def first_fastapi():
7476 return "hello, modelcache!"
7577
7678@app .post ("/modelcache" )
77- async def user_backend (request_data : RequestData ):
78- # param parsing
79+ async def user_backend (request : Request ):
7980 try :
80- request_type = request_data .type
81+ raw_body = await request .body ()
82+ # 解析字符串为JSON对象
83+ if isinstance (raw_body , bytes ):
84+ raw_body = raw_body .decode ("utf-8" )
85+ if isinstance (raw_body , str ):
86+ try :
87+ # 尝试将字符串解析为JSON对象
88+ request_data = json .loads (raw_body )
89+ except json .JSONDecodeError :
90+ # 如果无法解析,返回格式错误
91+ raise HTTPException (status_code = 400 , detail = "Invalid JSON format" )
92+ else :
93+ request_data = raw_body
94+
95+ # 确保request_data是字典对象
96+ if isinstance (request_data , str ):
97+ try :
98+ request_data = json .loads (request_data )
99+ except json .JSONDecodeError :
100+ raise HTTPException (status_code = 400 , detail = "Invalid JSON format" )
101+
102+ request_type = request_data .get ('type' )
81103 model = None
82- if request_data .scope :
83- model = request_data .scope .get ('model' , '' ).replace ('-' ,'_' ).replace ('.' , '_' )
84- query = request_data .query
85- chat_info = request_data .chat_info
104+ if 'scope' in request_data :
105+ model = request_data ['scope' ].get ('model' , '' ).replace ('-' , '_' ).replace ('.' , '_' )
106+ query = request_data .get ('query' )
107+ chat_info = request_data .get ('chat_info' )
108+
109+ if not request_type or request_type not in ['query' , 'insert' , 'remove' , 'detox' ]:
110+ raise HTTPException (status_code = 400 , detail = "Type exception, should be one of ['query', 'insert', 'remove', 'detox']" )
86111
87112 except Exception as e :
88- result = {"errorCode" : 103 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
113+ request_data = raw_body if 'raw_body' in locals () else None
114+ result = {
115+ "errorCode" : 103 ,
116+ "errorDesc" : str (e ),
117+ "cacheHit" : False ,
118+ "delta_time" : 0 ,
119+ "hit_query" : '' ,
120+ "answer" : '' ,
121+ "para_dict" : request_data
122+ }
89123 return result
90124
125+
91126 # model filter
92127 filter_resp = model_blacklist_filter (model , request_type )
93128 if isinstance (filter_resp , dict ):
@@ -101,8 +136,7 @@ async def user_backend(request_data: RequestData):
101136
102137 if response is None :
103138 result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : False , "delta_time" : delta_time , "hit_query" : '' , "answer" : '' }
104- # elif response in ['adapt_query_exception']:
105- elif isinstance (response , str ):
139+ elif response in ['adapt_query_exception' ]:
106140 result = {"errorCode" : 201 , "errorDesc" : response , "cacheHit" : False , "delta_time" : delta_time ,
107141 "hit_query" : '' , "answer" : '' }
108142 else :
@@ -111,7 +145,7 @@ async def user_backend(request_data: RequestData):
111145 result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : True , "delta_time" : delta_time , "hit_query" : hit_query , "answer" : answer }
112146
113147 delta_time_log = round (time .time () - start_time , 2 )
114- asyncio .create_task (save_query_info (result , model , query , delta_time_log ))
148+ asyncio .create_task (save_query_info_fastapi (result , model , query , delta_time_log ))
115149 return result
116150 except Exception as e :
117151 result = {"errorCode" : 202 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 ,
@@ -130,7 +164,7 @@ async def user_backend(request_data: RequestData):
130164 return {"errorCode" : 303 , "errorDesc" : str (e ), "writeStatus" : "exception" }
131165
132166 if request_type == 'remove' :
133- response = adapter .ChatCompletion .create_remove (model = model , remove_type = request_data .remove_type , id_list = request_data .id_list )
167+ response = adapter .ChatCompletion .create_remove (model = model , remove_type = request_data .get ( " remove_type" ) , id_list = request_data .get ( " id_list" ) )
134168 if not isinstance (response , dict ):
135169 return {"errorCode" : 401 , "errorDesc" : "" , "response" : response , "removeStatus" : "exception" }
136170
0 commit comments