11# -*- coding: utf-8 -*-
2+ import asyncio
3+ from contextlib import asynccontextmanager
24import uvicorn
35import json
4- from fastapi import FastAPI , Request , HTTPException
6+ from fastapi .responses import JSONResponse
7+ from fastapi import FastAPI , Request
58from modelcache .cache import Cache
6-
7- #创建一个FastAPI实例
8- app = FastAPI ()
9-
10- cache = Cache .init ("mysql" , "milvus" )
9+ from modelcache .embedding import EmbeddingModel
10+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook passed to FastAPI: build the shared cache, then serve.

    Awaits ``Cache.init`` with the MySQL / Milvus storage names, the
    all-mpnet-base-v2 HuggingFace embedding model and two embedding workers,
    and stores the first element of the returned pair in the module-level
    ``cache``. The second element is discarded. Nothing is done after the
    ``yield``, i.e. there is no teardown step.
    """
    global cache
    init_options = dict(
        sql_storage="mysql",
        vector_storage="milvus",
        embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2,
        embedding_workers_num=2,
    )
    cache, _ = await Cache.init(**init_options)
    yield


app = FastAPI(lifespan=lifespan)
# Placeholder only; the lifespan hook rebinds this to the initialised cache.
cache: Cache = None
1124
@app.get("/welcome")
async def first_fastapi():
    """Welcome endpoint: reply to GET /welcome with a fixed greeting string."""
    greeting = "hello, modelcache!"
    return greeting
1528
16-
@app.post("/modelcache")
async def user_backend(request: Request):
    """Handle POST /modelcache: parse the JSON body and dispatch it to the cache.

    Args:
        request: Incoming request whose body is parsed as JSON.

    Returns:
        Whatever ``cache.handle_request`` returns on success. If the body
        cannot be parsed, a ``JSONResponse`` with status 400 and an error
        payload. If ``cache.handle_request`` raises, a ``JSONResponse`` with
        status 500 whose ``errorDesc`` carries ``str(e)``.
    """
    import logging  # local import: keeps this block self-contained

    try:
        request_data = await request.json()
    except Exception:
        # Body could not be parsed as JSON: report a client error and stop.
        result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False,
                  "delta_time": 0, "hit_query": '', "answer": ''}
        return JSONResponse(status_code=400, content=result)

    try:
        return await cache.handle_request(request_data)
    except Exception as e:
        result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False,
                  "delta_time": 0, "hit_query": '', "answer": ''}
        try:
            cache.save_query_resp(result, model='', query='', delta_time=0)
        except Exception:
            # Recording the failure is best-effort: if it breaks too, the
            # client must still receive the structured 500 payload below
            # instead of an unhandled error from the bookkeeping call.
            logging.getLogger(__name__).exception(
                "failed to record error response for /modelcache"
            )
        return JSONResponse(status_code=500, content=result)
5844
59- # TODO: 可以修改为在命令行中使用`uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`的命令启动
if __name__ == '__main__':
    # Script entry point: serve `app` on all interfaces, port 5000, with the
    # asyncio event loop and the httptools HTTP implementation selected
    # explicitly instead of uvicorn's automatic choice.
    uvicorn.run(
        app,
        host='0.0.0.0',
        port=5000,
        loop="asyncio",
        http="httptools",
    )
0 commit comments