1+ # -*- coding: utf-8 -*-
2+ import time
3+ import uvicorn
4+ import asyncio
5+ import logging
6+ import configparser
7+ import json
8+ from fastapi import FastAPI , Request , HTTPException
9+ from pydantic import BaseModel
10+ from concurrent .futures import ThreadPoolExecutor
11+ from starlette .responses import PlainTextResponse
12+ import functools
13+
14+ from modelcache import cache
15+ from modelcache .adapter import adapter
16+ from modelcache .manager import CacheBase , VectorBase , get_data_manager
17+ from modelcache .similarity_evaluation .distance import SearchDistanceEvaluation
18+ from modelcache .processor .pre import query_multi_splicing
19+ from modelcache .processor .pre import insert_multi_splicing
20+ from modelcache .utils .model_filter import model_blacklist_filter
21+ from modelcache .embedding import Data2VecAudio
22+
23+ #创建一个FastAPI实例
24+ app = FastAPI ()
25+
26+ class RequestData (BaseModel ):
27+ type : str
28+ scope : dict = None
29+ query : str = None
30+ chat_info : dict = None
31+ remove_type : str = None
32+ id_list : list = []
33+
34+ data2vec = Data2VecAudio ()
35+ mysql_config = configparser .ConfigParser ()
36+ mysql_config .read ('modelcache/config/mysql_config.ini' )
37+
38+ milvus_config = configparser .ConfigParser ()
39+ milvus_config .read ('modelcache/config/milvus_config.ini' )
40+
41+ # redis_config = configparser.ConfigParser()
42+ # redis_config.read('modelcache/config/redis_config.ini')
43+
44+ # 初始化datamanager
45+ data_manager = get_data_manager (
46+ CacheBase ("mysql" , config = mysql_config ),
47+ VectorBase ("milvus" , dimension = data2vec .dimension , milvus_config = milvus_config )
48+ )
49+
50+ # # 使用redis初始化datamanager
51+ # data_manager = get_data_manager(
52+ # CacheBase("mysql", config=mysql_config),
53+ # VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
54+ # )
55+
56+ cache .init (
57+ embedding_func = data2vec .to_embeddings ,
58+ data_manager = data_manager ,
59+ similarity_evaluation = SearchDistanceEvaluation (),
60+ query_pre_embedding_func = query_multi_splicing ,
61+ insert_pre_embedding_func = insert_multi_splicing ,
62+ )
63+
64+ executor = ThreadPoolExecutor (max_workers = 6 )
65+
66+ # 异步保存查询信息
67+ async def save_query_info (result , model , query , delta_time_log ):
68+ loop = asyncio .get_running_loop ()
69+ func = functools .partial (cache .data_manager .save_query_resp , result , model = model , query = json .dumps (query , ensure_ascii = False ), delta_time = delta_time_log )
70+ await loop .run_in_executor (None , func )
71+
72+
73+
74+ @app .get ("/welcome" , response_class = PlainTextResponse )
75+ async def first_fastapi ():
76+ return "hello, modelcache!"
77+
78+ @app .post ("/modelcache" )
79+ async def user_backend (request : Request ):
80+ try :
81+ raw_body = await request .body ()
82+ # 解析字符串为JSON对象
83+ if isinstance (raw_body , bytes ):
84+ raw_body = raw_body .decode ("utf-8" )
85+ if isinstance (raw_body , str ):
86+ try :
87+ # 尝试将字符串解析为JSON对象
88+ request_data = json .loads (raw_body )
89+ except json .JSONDecodeError as e :
90+ # 如果无法解析,返回格式错误
91+ result = {"errorCode" : 101 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' ,
92+ "answer" : '' }
93+ asyncio .create_task (save_query_info (result , model = '' , query = '' , delta_time_log = 0 ))
94+ raise HTTPException (status_code = 101 , detail = "Invalid JSON format" )
95+ else :
96+ request_data = raw_body
97+
98+ # 确保request_data是字典对象
99+ if isinstance (request_data , str ):
100+ try :
101+ request_data = json .loads (request_data )
102+ except json .JSONDecodeError :
103+ raise HTTPException (status_code = 101 , detail = "Invalid JSON format" )
104+
105+ request_type = request_data .get ('type' )
106+ model = None
107+ if 'scope' in request_data :
108+ model = request_data ['scope' ].get ('model' , '' ).replace ('-' , '_' ).replace ('.' , '_' )
109+ query = request_data .get ('query' )
110+ chat_info = request_data .get ('chat_info' )
111+
112+ if not request_type or request_type not in ['query' , 'insert' , 'remove' , 'register' ]:
113+ result = {"errorCode" : 102 ,
114+ "errorDesc" : "type exception, should one of ['query', 'insert', 'remove', 'register']" ,
115+ "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
116+ asyncio .create_task (save_query_info (result , model = model , query = '' , delta_time_log = 0 ))
117+ raise HTTPException (status_code = 102 , detail = "Type exception, should be one of ['query', 'insert', 'remove', 'register']" )
118+
119+ except Exception as e :
120+ request_data = raw_body if 'raw_body' in locals () else None
121+ result = {
122+ "errorCode" : 103 ,
123+ "errorDesc" : str (e ),
124+ "cacheHit" : False ,
125+ "delta_time" : 0 ,
126+ "hit_query" : '' ,
127+ "answer" : '' ,
128+ "para_dict" : request_data
129+ }
130+ return result
131+
132+
133+ # model filter
134+ filter_resp = model_blacklist_filter (model , request_type )
135+ if isinstance (filter_resp , dict ):
136+ return filter_resp
137+
138+ if request_type == 'query' :
139+ try :
140+ start_time = time .time ()
141+ response = adapter .ChatCompletion .create_query (scope = {"model" : model }, query = query )
142+ delta_time = f"{ round (time .time () - start_time , 2 )} s"
143+
144+ if response is None :
145+ result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : False , "delta_time" : delta_time , "hit_query" : '' , "answer" : '' }
146+ elif response in ['adapt_query_exception' ]:
147+ result = {"errorCode" : 201 , "errorDesc" : response , "cacheHit" : False , "delta_time" : delta_time ,
148+ "hit_query" : '' , "answer" : '' }
149+ else :
150+ answer = response ['data' ]
151+ hit_query = response ['hitQuery' ]
152+ result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : True , "delta_time" : delta_time , "hit_query" : hit_query , "answer" : answer }
153+
154+ delta_time_log = round (time .time () - start_time , 2 )
155+ asyncio .create_task (save_query_info (result , model , query , delta_time_log ))
156+ return result
157+ except Exception as e :
158+ result = {"errorCode" : 202 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 ,
159+ "hit_query" : '' , "answer" : '' }
160+ logging .info (f'result: { str (result )} ' )
161+ return result
162+
163+ if request_type == 'insert' :
164+ try :
165+ response = adapter .ChatCompletion .create_insert (model = model , chat_info = chat_info )
166+ if response == 'success' :
167+ return {"errorCode" : 0 , "errorDesc" : "" , "writeStatus" : "success" }
168+ else :
169+ return {"errorCode" : 301 , "errorDesc" : response , "writeStatus" : "exception" }
170+ except Exception as e :
171+ return {"errorCode" : 303 , "errorDesc" : str (e ), "writeStatus" : "exception" }
172+
173+ if request_type == 'remove' :
174+ response = adapter .ChatCompletion .create_remove (model = model , remove_type = request_data .get ("remove_type" ), id_list = request_data .get ("id_list" ))
175+ if not isinstance (response , dict ):
176+ return {"errorCode" : 401 , "errorDesc" : "" , "response" : response , "removeStatus" : "exception" }
177+
178+ state = response .get ('status' )
179+ if state == 'success' :
180+ return {"errorCode" : 0 , "errorDesc" : "" , "response" : response , "writeStatus" : "success" }
181+ else :
182+ return {"errorCode" : 402 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
183+
184+ if request_type == 'register' :
185+ response = adapter .ChatCompletion .create_register (model = model )
186+ if response in ['create_success' , 'already_exists' ]:
187+ return {"errorCode" : 0 , "errorDesc" : "" , "response" : response , "writeStatus" : "success" }
188+ else :
189+ return {"errorCode" : 502 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
190+
191+ # TODO: 可以修改为在命令行中使用`uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`的命令启动
192+ if __name__ == '__main__' :
193+ uvicorn .run (app , host = '0.0.0.0' , port = 5000 )
0 commit comments