11# -*- coding: utf-8 -*-
22import os
33import time
4+ import uuid
45
56import pymysql
67import json
@@ -42,26 +43,59 @@ def _insert(self, data: List):
4243 answer_type = 0
4344 embedding_data = embedding_data .tobytes ()
4445 is_deleted = 0
46+ _id = str (uuid .uuid4 ())
4547
4648 table_name = "modelcache_llm_answer"
47- insert_sql = "INSERT INTO {} (question, answer, answer_type, model, embedding_data, is_deleted) VALUES (%s, %s, %s, %s, _binary%s, %s)" .format (table_name )
49+ insert_sql = f"""
50+ INSERT INTO { table_name }
51+ (id, question, answer, answer_type, model, embedding_data, is_deleted)
52+ VALUES (%s, %s, %s, %s, %s, _binary%s, %s)
53+ """
4854 conn = self .pool .connection ()
4955 try :
5056 with conn .cursor () as cursor :
5157 # 执行插入数据操作
52- values = (question , answer , answer_type , model , embedding_data , is_deleted )
58+ values = (_id , question , answer , answer_type , model , embedding_data , is_deleted )
5359 cursor .execute (insert_sql , values )
5460 conn .commit ()
55- id = cursor .lastrowid
5661 finally :
5762 # 关闭连接,将连接返回给连接池
5863 conn .close ()
59- return id
64+ return _id
6065
61- def batch_insert (self , all_data : List [CacheData ]):
66+ def batch_insert (self , all_data : List [List ]):
67+ table_name = "modelcache_llm_answer"
68+ insert_sql = f"""
69+ INSERT INTO { table_name }
70+ (id, question, answer, answer_type, model, embedding_data, is_deleted)
71+ VALUES (%s, %s, %s, %s, %s, %s, %s)
72+ """
73+
74+ values_list = []
6275 ids = []
76+
6377 for data in all_data :
64- ids .append (self ._insert (data ))
78+ answer = data [0 ]
79+ question = data [1 ]
80+ embedding_data = data [2 ].tobytes ()
81+ model = data [3 ]
82+ answer_type = 0
83+ is_deleted = 0
84+ _id = str (uuid .uuid4 ())
85+ ids .append (_id )
86+
87+ values_list .append ((
88+ _id , question , answer , answer_type , model , embedding_data , is_deleted
89+ ))
90+
91+ conn = self .pool .connection ()
92+ try :
93+ with conn .cursor () as cursor :
94+ cursor .executemany (insert_sql , values_list )
95+ conn .commit ()
96+ finally :
97+ conn .close ()
98+
6599 return ids
66100
67101 def insert_query_resp (self , query_resp , ** kwargs ):
@@ -78,7 +112,11 @@ def insert_query_resp(self, query_resp, **kwargs):
78112 hit_query = json .dumps (hit_query , ensure_ascii = False )
79113
80114 table_name = "modelcache_query_log"
81- insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)" .format (table_name )
115+ insert_sql = f"""
116+ INSERT INTO { table_name }
117+ (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer)
118+ VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
119+ """
82120 conn = self .pool .connection ()
83121 try :
84122 with conn .cursor () as cursor :
@@ -92,15 +130,16 @@ def insert_query_resp(self, query_resp, **kwargs):
92130
93131 def get_data_by_id (self , key : int ):
94132 table_name = "modelcache_llm_answer"
95- query_sql = "select question, answer, embedding_data, model from {} where id={}" .format (table_name , key )
96- conn_start = time .time ()
133+ query_sql = f"""
134+ SELECT question, answer, embedding_data, model
135+ FROM { table_name }
136+ WHERE id = %s
137+ """
97138 conn = self .pool .connection ()
98-
99- search_start = time .time ()
100139 try :
101140 with conn .cursor () as cursor :
102141 # 执行数据库操作
103- cursor .execute (query_sql )
142+ cursor .execute (query_sql , ( key ,) )
104143 resp = cursor .fetchone ()
105144 finally :
106145 # 关闭连接,将连接返回给连接池
@@ -113,14 +152,18 @@ def get_data_by_id(self, key: int):
113152
114153 def update_hit_count_by_id (self , primary_id : int ):
115154 table_name = "modelcache_llm_answer"
116- update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id={}" .format (table_name , primary_id )
155+ update_sql = f"""
156+ UPDATE { table_name }
157+ SET hit_count = hit_count+1
158+ WHERE id = %s
159+ """
117160 conn = self .pool .connection ()
118161
119162 # 使用连接执行更新数据操作
120163 try :
121164 with conn .cursor () as cursor :
122165 # 执行更新数据操作
123- cursor .execute (update_sql )
166+ cursor .execute (update_sql ,( primary_id ,) )
124167 conn .commit ()
125168 finally :
126169 # 关闭连接,将连接返回给连接池
@@ -129,12 +172,16 @@ def update_hit_count_by_id(self, primary_id: int):
129172 def get_ids (self , deleted = True ):
130173 table_name = "modelcache_llm_answer"
131174 state = 1 if deleted else 0
132- query_sql = "Select id FROM {} WHERE is_deleted = {}" .format (table_name , state )
175+ query_sql = f"""
176+ SELECT id
177+ FROM { table_name }
178+ WHERE is_deleted = %s
179+ """
133180
134181 conn = self .pool .connection ()
135182 try :
136183 with conn .cursor () as cursor :
137- cursor .execute (query_sql )
184+ cursor .execute (query_sql , ( state ,) )
138185 ids = [row [0 ] for row in cursor .fetchall ()]
139186 finally :
140187 conn .close ()
@@ -143,37 +190,45 @@ def get_ids(self, deleted=True):
143190
144191 def mark_deleted (self , keys ):
145192 table_name = "modelcache_llm_answer"
146- mark_sql = " update {} set is_deleted=1 WHERE id in ({})" .format (table_name , "," .join ([str (i ) for i in keys ]))
193+ placeholders = "," .join (["%s" ] * len (keys ))
194+ mark_sql = f"""
195+ UPDATE { table_name }
196+ SET is_deleted=1
197+ WHERE id in ({ placeholders } )
198+ """
147199
148- # 从连接池中获取连接
149200 conn = self .pool .connection ()
150201 try :
151202 with conn .cursor () as cursor :
152- # 执行删除数据操作
153- cursor .execute (mark_sql )
203+ cursor .execute (mark_sql , keys )
154204 delete_count = cursor .rowcount
155205 conn .commit ()
156206 finally :
157- # 关闭连接,将连接返回给连接池
158207 conn .close ()
159208 return delete_count
160209
161210 def model_deleted (self , model_name ):
162211 table_name = "modelcache_llm_answer"
163- delete_sql = "Delete from {} WHERE model='{}'" .format (table_name , model_name )
212+ delete_sql = f"""
213+ Delete from { table_name }
214+ WHERE model = %s
215+ """
164216
165217 table_log_name = "modelcache_query_log"
166- delete_log_sql = "Delete from {} WHERE model='{}'" .format (table_log_name , model_name )
218+ delete_log_sql = f"""
219+ Delete from { table_log_name }
220+ WHERE model = %s
221+ """
167222
168223 conn = self .pool .connection ()
169224 # 使用连接执行删除数据操作
170225 try :
171226 with conn .cursor () as cursor :
172227 # 执行删除数据操作
173- resp = cursor .execute (delete_sql )
228+ resp = cursor .execute (delete_sql , ( model_name ,) )
174229 conn .commit ()
175230 # 执行删除该模型对应日志操作 resp_log行数不返回
176- resp_log = cursor .execute (delete_log_sql )
231+ resp_log = cursor .execute (delete_log_sql , ( model_name ,))
177232 conn .commit () # 分别提交事务
178233 finally :
179234 # 关闭连接,将连接返回给连接池
@@ -182,7 +237,10 @@ def model_deleted(self, model_name):
182237
183238 def clear_deleted_data (self ):
184239 table_name = "modelcache_llm_answer"
185- delete_sql = "DELETE FROM {} WHERE is_deleted = 1" .format (table_name )
240+ delete_sql = f"""
241+ DELETE FROM { table_name }
242+ WHERE is_deleted = 1
243+ """
186244
187245 conn = self .pool .connection ()
188246 try :
@@ -197,10 +255,15 @@ def clear_deleted_data(self):
197255
198256 def count (self , state : int = 0 , is_all : bool = False ):
199257 table_name = "modelcache_llm_answer"
258+
259+ # we're not using prepared statements here, so we need to ensure state is an integer
260+ if not isinstance (state , int ):
261+ raise ValueError ("'state' must be an integer." )
262+
200263 if is_all :
201- count_sql = "SELECT COUNT(*) FROM {}" . format ( table_name )
264+ count_sql = f "SELECT COUNT(*) FROM { table_name } "
202265 else :
203- count_sql = "SELECT COUNT(*) FROM {} WHERE is_deleted = {}" . format ( table_name , state )
266+ count_sql = f "SELECT COUNT(*) FROM { table_name } WHERE is_deleted = { state } "
204267
205268 conn = self .pool .connection ()
206269 try :
0 commit comments