Skip to content

Commit 9fce2e4

Browse files
committed
Develop multimodal insert capability
1 parent e3eac54 commit 9fce2e4

File tree

11 files changed

+219
-172
lines changed

11 files changed

+219
-172
lines changed

flask4modelcache.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def user_backend():
177177
return json.dumps(result)
178178

179179
if request_type == 'register':
180-
# iat_type = param_dict.get("iat_type")
181180
response = adapter.ChatCompletion.create_register(
182181
model=model
183182
)

modelcache/adapter_mm/adapter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
from modelcache.adapter_mm.adapter_query import adapt_query
55
from modelcache.adapter_mm.adapter_insert import adapt_insert
6-
from modelcache.adapter.adapter_remove import adapt_remove
7-
from modelcache.adapter.adapter_register import adapt_register
6+
from modelcache.adapter_mm.adapter_remove import adapt_remove
7+
from modelcache.adapter_mm.adapter_register import adapt_register
88

99

1010
class ChatCompletion(object):
@@ -30,7 +30,8 @@ def create_mm_insert(cls, *args, **kwargs):
3030
**kwargs
3131
)
3232
except Exception as e:
33-
return str(e)
33+
# return str(e)
34+
raise e
3435

3536
@classmethod
3637
def create_mm_remove(cls, *args, **kwargs):
@@ -51,7 +52,7 @@ def create_mm_register(cls, *args, **kwargs):
5152
**kwargs
5253
)
5354
except Exception as e:
54-
return str(e)
55+
raise e
5556

5657

5758
def construct_resp_from_cache(return_message, return_query):

modelcache/adapter_mm/adapter_insert.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# -*- coding: utf-8 -*-
2+
import time
3+
import requests
4+
import base64
5+
import numpy as np
26
from modelcache import cache
37
from modelcache.utils.error import NotInitError
48
from modelcache.utils.time import time_cal
@@ -15,26 +19,77 @@ def adapt_insert(*args, **kwargs):
1519
cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
1620
context = kwargs.pop("cache_context", {})
1721
embedding_data = None
18-
pre_embedding_data = chat_cache.insert_pre_embedding_func(
22+
pre_embedding_data_dict = chat_cache.mm_insert_pre_embedding_func(
1923
kwargs,
2024
extra_param=context.get("pre_embedding_func", None),
2125
prompts=chat_cache.config.prompts,
2226
)
27+
28+
print('pre_embedding_data_dict: {}'.format(pre_embedding_data_dict))
2329
chat_info = kwargs.pop("chat_info", [])
2430
llm_data = chat_info[-1]['answer']
2531

32+
pre_embedding_text = '###'.join(pre_embedding_data_dict['text'])
33+
pre_embedding_image_url = pre_embedding_data_dict['imageUrl']
34+
pre_embedding_image_raw = pre_embedding_data_dict['imageRaw']
35+
pre_embedding_image_id = pre_embedding_data_dict.get('imageId', None)
36+
37+
if pre_embedding_image_url and pre_embedding_image_raw:
38+
raise ValueError("Both pre_embedding_image_url and pre_embedding_image_raw cannot be non-empty at the same time.")
39+
40+
if pre_embedding_image_url:
41+
url_start_time = time.time()
42+
response = requests.get(pre_embedding_image_url)
43+
image_data = response.content
44+
pre_embedding_image = base64.b64encode(image_data).decode('utf-8')
45+
get_image_time = '{}s'.format(round(time.time() - url_start_time, 2))
46+
print('get_image_time: {}'.format(get_image_time))
47+
elif pre_embedding_image_raw:
48+
pre_embedding_image = pre_embedding_image_raw
49+
else:
50+
pre_embedding_image = None
51+
if not pre_embedding_text:
52+
raise ValueError(
53+
"Both pre_embedding_image_url and pre_embedding_image_raw are empty. Please provide at least one.")
54+
55+
data_dict = {'text': [pre_embedding_text], 'image': pre_embedding_image}
56+
embedding_data = None
57+
mm_type = None
58+
2659
if cache_enable:
27-
embedding_data = time_cal(
60+
embedding_data_resp = time_cal(
2861
chat_cache.embedding_func,
29-
func_name="embedding",
62+
func_name="image_embedding",
3063
report_func=chat_cache.report.embedding,
31-
)(pre_embedding_data)
64+
)(data_dict)
65+
66+
image_embeddings = embedding_data_resp['image_embedding']
67+
text_embeddings = embedding_data_resp['text_embeddings']
68+
69+
if len(image_embeddings) > 0 and len(image_embeddings) > 0:
70+
image_embedding = np.array(image_embeddings[0])
71+
text_embedding = text_embeddings[0]
72+
embedding_data = np.concatenate((image_embedding, text_embedding))
73+
mm_type = 'mm'
74+
elif len(image_embeddings) > 0:
75+
image_embedding = np.array(image_embeddings[0])
76+
embedding_data = image_embedding
77+
mm_type = 'image'
78+
elif len(text_embeddings) > 0:
79+
text_embedding = np.array(text_embeddings[0])
80+
embedding_data = text_embedding
81+
mm_type = 'text'
82+
else:
83+
raise ValueError('maya embedding service return both empty list, please check!')
3284

3385
chat_cache.data_manager.save(
34-
pre_embedding_data,
86+
pre_embedding_text,
87+
pre_embedding_image_url,
88+
pre_embedding_image_id,
3589
llm_data,
3690
embedding_data,
3791
model=model,
38-
extra_param=context.get("save_func", None)
92+
mm_type=mm_type,
93+
extra_param=context.get("mm_save_func", None)
3994
)
40-
return 'success'
95+
return 'success'

modelcache/adapter_mm/adapter_register.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
def adapt_register(*args, **kwargs):
    """Register (create) an index for a model through the cache's data manager.

    Keyword Args:
        cache_obj: Cache object to use; falls back to the module-level
            ``cache`` when not supplied.
        model: Model identifier. Must be non-empty.
        mm_type: Value forwarded as the second argument of
            ``data_manager.create_index``; ``None`` when not supplied.

    Returns:
        Whatever ``chat_cache.data_manager.create_index(model, mm_type)``
        returns.

    Raises:
        ValueError: If ``model`` is ``None`` or empty.
    """
    # Look up the default lazily so an explicitly supplied cache object
    # never depends on the global one.
    chat_cache = kwargs.pop("cache_obj") if "cache_obj" in kwargs else cache
    model = kwargs.pop("model", None)
    mm_type = kwargs.pop("mm_type", None)

    if model is None or len(model) == 0:
        # Previously the exception object was *returned*, so callers received
        # it as if it were a normal response; raise it instead.
        raise ValueError("a non-empty 'model' is required to register an index")

    print('mm_type: {}'.format(mm_type))
    print('model: {}'.format(model))
    register_resp = chat_cache.data_manager.create_index(model, mm_type)
    print('register_resp: {}'.format(register_resp))
    return register_resp

modelcache/manager_mm/data_manager.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,21 @@
2525
class DataManager(metaclass=ABCMeta):
2626
"""DataManager manage the cache data, including save and search"""
2727

28+
# @abstractmethod
29+
# def save(self, question, answer, embedding_data, **kwargs):
30+
# pass
31+
2832
@abstractmethod
29-
def save(self, question, answer, embedding_data, **kwargs):
33+
def save(self, text, image_url, image_id, answer, embedding, **kwargs):
3034
pass
3135

3236
@abstractmethod
3337
def save_query_resp(self, query_resp_dict, **kwargs):
3438
pass
3539

3640
@abstractmethod
37-
def import_data(
38-
self, questions: List[Any], answers: List[Any], embedding_datas: List[Any], model:Any
39-
):
41+
def import_data(self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Answer],
42+
embeddings: List[Any], model: Any, iat_type: Any):
4043
pass
4144

4245
@abstractmethod
@@ -89,21 +92,20 @@ def init(self):
8992
f"You don't have permission to access this file <{self.data_path}>."
9093
)
9194

92-
def save(self, question, answer, embedding_data, **kwargs):
93-
if isinstance(question, Question):
94-
question = question.content
95-
self.data[embedding_data] = (question, answer, embedding_data)
95+
# def save(self, question, answer, embedding_data, **kwargs):
96+
# if isinstance(question, Question):
97+
# question = question.content
98+
# self.data[embedding_data] = (question, answer, embedding_data)
99+
100+
def save(self, text, image_url, image_id, answer, embedding, **kwargs):
101+
pass
96102

97103
def save_query_resp(self, query_resp_dict, **kwargs):
98104
pass
99105

100-
def import_data(
101-
self, questions: List[Any], answers: List[Any], embedding_datas: List[Any], model: Any
102-
):
103-
if len(questions) != len(answers) or len(questions) != len(embedding_datas):
104-
raise ParamError("Make sure that all parameters have the same length")
105-
for i, embedding_data in enumerate(embedding_datas):
106-
self.data[embedding_data] = (questions[i], answers[i], embedding_datas[i])
106+
def import_data(self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Answer],
107+
embeddings: List[Any], model: Any, iat_type: Any):
108+
pass
107109

108110
def get_scalar_data(self, res_data, **kwargs) -> CacheData:
109111
return CacheData(question=res_data[0], answers=res_data[1])
@@ -158,9 +160,15 @@ def __init__(
158160
self.v = v
159161
self.o = o
160162

161-
def save(self, question, answer, embedding_data, **kwargs):
163+
# def save(self, question, answer, embedding_data, **kwargs):
164+
# model = kwargs.pop("model", None)
165+
# self.import_data([question], [answer], [embedding_data], model)
166+
167+
def save(self, text, image_url, image_id, answer, embedding, **kwargs):
    """Store a single cache entry by delegating to ``import_data``.

    Each positional argument is wrapped in a one-element list. ``model`` and
    ``mm_type`` are taken from ``kwargs`` (``None`` when absent) and passed
    as the last two arguments of ``import_data``.
    """
    target_model, modality = kwargs.pop("model", None), kwargs.pop("mm_type", None)
    self.import_data(
        [text],
        [image_url],
        [image_id],
        [answer],
        [embedding],
        target_model,
        modality,
    )
164172

165173
def save_query_resp(self, query_resp_dict, **kwargs):
166174
save_query_start_time = time.time()
@@ -190,36 +198,38 @@ def _process_question_data(self, question: Union[str, Question]):
190198

191199
return Question(question)
192200

193-
def import_data(
194-
self, questions: List[Any], answers: List[Answer], embedding_datas: List[Any], model: Any
195-
):
196-
if len(questions) != len(answers) or len(questions) != len(embedding_datas):
201+
def import_data(self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Answer],
                embeddings: List[Any], model: Any, iat_type: Any):
    """Insert a batch of entries into the scalar store and the vector store.

    The five list arguments are parallel: index ``i`` of each describes the
    same entry.

    Args:
        texts: Text of each entry.
        image_urls: Image URL of each entry.
        image_ids: Image id of each entry.
        answers: Answer of each entry; passed through
            ``_process_answer_data`` when ``self.o`` is set.
        embeddings: Embedding of each entry; each is passed through
            ``normalize`` before being stored.
        model: Model identifier stored with every row and forwarded to the
            vector store.
        iat_type: Forwarded unchanged as the last argument of
            ``self.v.iat_add``.

    Raises:
        ParamError: If the five lists do not all have the same length.
    """
    # Validate every parallel list, not just texts/answers: a short list would
    # otherwise surface later as an IndexError, and a surplus of embeddings
    # would be indexed against ids that were never inserted.
    expected_len = len(texts)
    if any(len(seq) != expected_len for seq in (image_urls, image_ids, answers, embeddings)):
        raise ParamError("Make sure that all parameters have the same length")

    embeddings = [normalize(embedding) for embedding in embeddings]

    cache_datas = []
    for text, image_url, image_id, answer in zip(texts, image_urls, image_ids, answers):
        ans = self._process_answer_data(answer) if self.o is not None else answer
        cache_datas.append([ans, text, image_url, image_id, model])

    ids = self.s.batch_iat_insert(cache_datas)
    self.v.iat_add(
        [
            VectorData(id=ids[i], data=embedding)
            for i, embedding in enumerate(embeddings)
        ],
        model,
        iat_type
    )
224234

225235
def get_scalar_data(self, res_data, **kwargs) -> Optional[CacheData]:
@@ -256,8 +266,8 @@ def delete(self, id_list, **kwargs):
256266
return {'status': 'success', 'milvus': 'delete_count: '+str(v_delete_count),
257267
'mysql': 'delete_count: '+str(s_delete_count)}
258268

259-
def create_index(self, model, **kwargs):
260-
return self.v.create(model)
269+
def create_index(self, model, mm_type, **kwargs):
    """Delegate index creation to ``self.v.create`` and return its result.

    Extra keyword arguments are accepted but not used.
    """
    vector_store = self.v
    created = vector_store.create(model, mm_type)
    return created
261271

262272
def truncate(self, model_name):
263273
# drop vector base data

modelcache/manager_mm/factory.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
from typing import Union, Callable
3-
from modelcache.manager import CacheBase, VectorBase, ObjectBase
4-
from modelcache.manager.data_manager import SSDataManager, MapDataManager
3+
from modelcache.manager_mm import CacheBase, VectorBase, ObjectBase
4+
from modelcache.manager_mm.data_manager import SSDataManager, MapDataManager
55

66

77
def get_data_manager(
@@ -25,26 +25,3 @@ def get_data_manager(
2525
object_base = ObjectBase(name=object_base)
2626
assert cache_base and vector_base
2727
return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size, eviction)
28-
29-
30-
def get_data_manager_mm(
31-
cache_base: Union[CacheBase, str] = None,
32-
vector_base: Union[VectorBase, str] = None,
33-
object_base: Union[ObjectBase, str] = None,
34-
max_size: int = 1000,
35-
clean_size: int = None,
36-
eviction: str = "LRU",
37-
data_path: str = "data_map.txt",
38-
get_data_container: Callable = None,
39-
):
40-
if not cache_base and not vector_base:
41-
return MapDataManager(data_path, max_size, get_data_container)
42-
43-
if isinstance(cache_base, str):
44-
cache_base = CacheBase(name=cache_base)
45-
if isinstance(vector_base, str):
46-
vector_base = VectorBase(name=vector_base)
47-
if isinstance(object_base, str):
48-
object_base = ObjectBase(name=object_base)
49-
assert cache_base and vector_base
50-
return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size, eviction)

modelcache/manager_mm/scalar_data/sql_storage.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,24 @@ def create(self):
3636

3737
def _insert(self, data: List):
    """Insert one row into ``multimodal_answer`` and return ``cursor.lastrowid``.

    ``data`` is read by position: answer, text, image URL, image id, model.
    ``answer_type`` is always written as ``0``.
    """
    answer, text, image_url, image_id, model = (
        data[0], data[1], data[2], data[3], data[4]
    )
    answer_type = 0

    insert_sql = (
        "INSERT INTO {} (question_text, image_url, image_id, answer, answer_type, model) "
        "VALUES (%s, %s, %s, %s, %s, %s)"
    ).format("multimodal_answer")
    row = (text, image_url, image_id, answer, answer_type, model)

    conn = self.pool.connection()
    try:
        with conn.cursor() as cursor:
            # data insert operation
            cursor.execute(insert_sql, row)
            conn.commit()
            row_id = cursor.lastrowid
    finally:
        # Close the connection and return it back to the connection pool
        conn.close()
    return row_id
6059

modelcache/manager_mm/vector_data/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def delete(self, ids) -> bool:
3131
pass
3232

3333
@abstractmethod
34-
def rebuild_col(self, model):
34+
def rebuild_idx(self, model):
3535
pass
3636

3737
def flush(self):

0 commit comments

Comments
 (0)