|
25 | 25 | class DataManager(metaclass=ABCMeta): |
26 | 26 | """DataManager manage the cache data, including save and search""" |
27 | 27 |
|
| 28 | + # @abstractmethod |
| 29 | + # def save(self, question, answer, embedding_data, **kwargs): |
| 30 | + # pass |
| 31 | + |
28 | 32 | @abstractmethod |
29 | | - def save(self, question, answer, embedding_data, **kwargs): |
| 33 | + def save(self, text, image_url, image_id, answer, embedding, **kwargs): |
30 | 34 | pass |
31 | 35 |
|
32 | 36 | @abstractmethod |
33 | 37 | def save_query_resp(self, query_resp_dict, **kwargs): |
34 | 38 | pass |
35 | 39 |
|
36 | 40 | @abstractmethod |
37 | | - def import_data( |
38 | | - self, questions: List[Any], answers: List[Any], embedding_datas: List[Any], model:Any |
39 | | - ): |
| 41 | + def import_data(self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Answer], |
| 42 | + embeddings: List[Any], model: Any, iat_type: Any): |
40 | 43 | pass |
41 | 44 |
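For reference, a minimal sketch of a call against the new image-aware interface; the `dm` handle, the helper name, and the argument values are placeholders for whatever concrete DataManager and adapter code actually supplies them, not something defined in this commit:

    # Sketch only: the old single-question save() becomes a text + image save().
    def cache_image_answer(dm, text, image_url, image_id, answer, embedding):
        # `model` (and, further down, `mm_type`) still travel through **kwargs.
        dm.save(text, image_url, image_id, answer, embedding, model="default")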
|
42 | 45 | @abstractmethod |
@@ -89,21 +92,20 @@ def init(self): |
89 | 92 | f"You don't have permission to access this file <{self.data_path}>." |
90 | 93 | ) |
91 | 94 |
|
92 | | - def save(self, question, answer, embedding_data, **kwargs): |
93 | | - if isinstance(question, Question): |
94 | | - question = question.content |
95 | | - self.data[embedding_data] = (question, answer, embedding_data) |
| 95 | + # def save(self, question, answer, embedding_data, **kwargs): |
| 96 | + # if isinstance(question, Question): |
| 97 | + # question = question.content |
| 98 | + # self.data[embedding_data] = (question, answer, embedding_data) |
| 99 | + |
| 100 | + def save(self, text, image_url, image_id, answer, embedding, **kwargs): |
| 101 | + pass |
96 | 102 |
|
97 | 103 | def save_query_resp(self, query_resp_dict, **kwargs): |
98 | 104 | pass |
99 | 105 |
|
100 | | - def import_data( |
101 | | - self, questions: List[Any], answers: List[Any], embedding_datas: List[Any], model: Any |
102 | | - ): |
103 | | - if len(questions) != len(answers) or len(questions) != len(embedding_datas): |
104 | | - raise ParamError("Make sure that all parameters have the same length") |
105 | | - for i, embedding_data in enumerate(embedding_datas): |
106 | | - self.data[embedding_data] = (questions[i], answers[i], embedding_datas[i]) |
| 106 | + def import_data(self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Answer], |
| 107 | + embeddings: List[Any], model: Any, iat_type: Any): |
| 108 | + pass |
107 | 109 |
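The map-backed manager now stubs both methods out. One possible in-memory version of the new signatures, following the tuple layout of the commented-out code, could look like the sketch below; it is not part of this commit and assumes the embedding is hashable (e.g. a tuple or bytes):

    # Sketch: dict-backed storage for the new multimodal fields.
    def save(self, text, image_url, image_id, answer, embedding, **kwargs):
        self.data[embedding] = (text, image_url, image_id, answer, embedding)

    def import_data(self, texts, image_urls, image_ids, answers, embeddings,
                    model, iat_type):
        if not len(texts) == len(image_urls) == len(image_ids) == len(answers) == len(embeddings):
            raise ParamError("Make sure that all parameters have the same length")
        for i, embedding in enumerate(embeddings):
            self.data[embedding] = (texts[i], image_urls[i], image_ids[i],
                                    answers[i], embedding)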
|
108 | 110 | def get_scalar_data(self, res_data, **kwargs) -> CacheData: |
109 | 111 | return CacheData(question=res_data[0], answers=res_data[1]) |
@@ -158,9 +160,15 @@ def __init__( |
158 | 160 | self.v = v |
159 | 161 | self.o = o |
160 | 162 |
|
161 | | - def save(self, question, answer, embedding_data, **kwargs): |
| 163 | + # def save(self, question, answer, embedding_data, **kwargs): |
| 164 | + # model = kwargs.pop("model", None) |
| 165 | + # self.import_data([question], [answer], [embedding_data], model) |
| 166 | + |
| 167 | + def save(self, text, image_url, image_id, answer, embedding, **kwargs): |
162 | 168 | model = kwargs.pop("model", None) |
163 | | - self.import_data([question], [answer], [embedding_data], model) |
| 169 | + mm_type = kwargs.pop("mm_type", None) |
| 170 | + self.import_data([text], [image_url], [image_id], [answer], |
| 171 | + [embedding], model, mm_type) |
164 | 172 |
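save() stays a thin wrapper: it pops `model` and `mm_type` from kwargs and delegates to import_data() with single-element lists, where the same modality flag arrives under the name `iat_type`. Spelled out as a standalone helper (names are illustrative, not from this commit):

    # The delegation save() performs, written explicitly.
    def save_one(dm, text, image_url, image_id, answer, embedding,
                 model=None, mm_type=None):
        dm.import_data([text], [image_url], [image_id], [answer],
                       [embedding], model, mm_type)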
|
165 | 173 | def save_query_resp(self, query_resp_dict, **kwargs): |
166 | 174 | save_query_start_time = time.time() |
@@ -190,36 +198,38 @@ def _process_question_data(self, question: Union[str, Question]): |
190 | 198 |
|
191 | 199 | return Question(question) |
192 | 200 |
|
193 | | - def import_data( |
194 | | - self, questions: List[Any], answers: List[Answer], embedding_datas: List[Any], model: Any |
195 | | - ): |
196 | | - if len(questions) != len(answers) or len(questions) != len(embedding_datas): |
| 201 | + def import_data(self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Answer], |
| 202 | + embeddings: List[Any], model: Any, iat_type: Any): |
| 203 | + if not len(texts) == len(image_urls) == len(image_ids) == len(answers) == len(embeddings): |
197 | 204 | raise ParamError("Make sure that all parameters have the same length") |
198 | 205 | cache_datas = [] |
199 | 206 |
|
200 | | - embedding_datas = [ |
201 | | - normalize(embedding_data) for embedding_data in embedding_datas |
| 207 | + embeddings = [ |
| 208 | + normalize(text_embedding) for text_embedding in embeddings |
202 | 209 | ] |
203 | 210 |
|
204 | | - for i, embedding_data in enumerate(embedding_datas): |
| 211 | + # print('embedding_datas: {}'.format(embedding_datas)) |
| 212 | + for i, embedding in enumerate(embeddings): |
205 | 213 | if self.o is not None: |
206 | 214 | ans = self._process_answer_data(answers[i]) |
207 | 215 | else: |
208 | 216 | ans = answers[i] |
209 | | - |
210 | | - question = questions[i] |
211 | | - embedding_data = embedding_data.astype("float32") |
212 | | - cache_datas.append([ans, question, embedding_data, model]) |
213 | | - |
214 | | - ids = self.s.batch_insert(cache_datas) |
215 | | - logging.info('ids: {}'.format(ids)) |
216 | | - self.v.mul_add( |
| 217 | + text = texts[i] |
| 218 | + image_url = image_urls[i] |
| 219 | + image_id = image_ids[i] |
| 220 | + # iat_embedding = embedding.astype("float32") |
| 221 | + cache_datas.append([ans, text, image_url, image_id, model]) |
| 222 | + |
| 223 | + # ids = self.s.batch_multimodal_insert(cache_datas) |
| 224 | + ids = self.s.batch_iat_insert(cache_datas) |
| 225 | + # self.v.multimodal_add( |
| 226 | + self.v.iat_add( |
217 | 227 | [ |
218 | | - VectorData(id=ids[i], data=embedding_data) |
219 | | - for i, embedding_data in enumerate(embedding_datas) |
| 228 | + VectorData(id=ids[i], data=embedding) |
| 229 | + for i, embedding in enumerate(embeddings) |
220 | 230 | ], |
221 | | - model |
222 | | - |
| 231 | + model, |
| 232 | + iat_type |
223 | 233 | ) |
224 | 234 |
|
225 | 235 | def get_scalar_data(self, res_data, **kwargs) -> Optional[CacheData]: |
@@ -256,8 +266,8 @@ def delete(self, id_list, **kwargs): |
256 | 266 | return {'status': 'success', 'milvus': 'delete_count: '+str(v_delete_count), |
257 | 267 | 'mysql': 'delete_count: '+str(s_delete_count)} |
258 | 268 |
|
259 | | - def create_index(self, model, **kwargs): |
260 | | - return self.v.create(model) |
| 269 | + def create_index(self, model, mm_type, **kwargs): |
| 270 | + return self.v.create(model, mm_type) |
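create_index() now takes the modality flag as well, so a vector collection is created per (model, mm_type) pair rather than per model alone. A hypothetical warm-up loop, with placeholder names not used by this commit:

    # Sketch: pre-create one index per model/modality combination.
    def ensure_indexes(dm, models, mm_types):
        for model in models:
            for mm_type in mm_types:
                dm.create_index(model, mm_type)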
261 | 271 |
|
262 | 272 | def truncate(self, model_name): |
263 | 273 | # drop vector base data |
|