Skip to content

Commit b474d15

Browse files
Changed the embedding model to mpnet-base-v2
Added flags for its usage and updated Milvus to use cosine similarity instead of L2 distance for searching. Co-authored-by: olgaoznovich <ol.oznovich@gmail.com> Co-authored-by: Yuval-Roth <rothyuv@post.bgu.ac.il>
1 parent 5127d77 commit b474d15

File tree

6 files changed

+137
-73
lines changed

6 files changed

+137
-73
lines changed

flask4modelcache.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import json
77
from modelcache import cache
88
from modelcache.adapter import adapter
9+
from modelcache.embedding.mpnet_base import MPNet_Base
10+
from modelcache.manager.vector_data import manager
911
from modelcache.manager import CacheBase, VectorBase, get_data_manager, data_manager
1012
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
1113
from modelcache.processor.pre import query_multi_splicing
@@ -30,9 +32,17 @@ def save_query_info(result, model, query, delta_time_log):
3032
def response_hitquery(cache_resp):
3133
return cache_resp['hitQuery']
3234

33-
data2vec = Data2VecAudio()
34-
embedding_func = data2vec.to_embeddings
35-
dimension = data2vec.dimension
35+
manager.MPNet_base = True
36+
37+
if manager.MPNet_base:
38+
mpnet_base = MPNet_Base()
39+
embedding_func = lambda x: mpnet_base.embedding_func(x)
40+
dimension = mpnet_base.dimension
41+
data_manager.NORMALIZE = False
42+
else:
43+
data2vec = Data2VecAudio()
44+
embedding_func = data2vec.to_embeddings
45+
dimension = data2vec.dimension
3646

3747
mysql_config = configparser.ConfigParser()
3848
mysql_config.read('modelcache/config/mysql_config.ini')
@@ -49,8 +59,30 @@ def response_hitquery(cache_resp):
4959
# chromadb_config = configparser.ConfigParser()
5060
# chromadb_config.read('modelcache/config/chromadb_config.ini')
5161

52-
data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
53-
VectorBase("milvus", dimension=dimension, milvus_config=milvus_config))
62+
data_manager = get_data_manager(
63+
CacheBase("mysql", config=mysql_config),
64+
VectorBase("milvus",
65+
dimension=dimension,
66+
milvus_config=milvus_config,
67+
index_params={
68+
"metric_type": "COSINE",
69+
"index_type": "HNSW",
70+
"params": {"M": 16, "efConstruction": 64},
71+
} if manager.MPNet_base else None,
72+
search_params={
73+
"IVF_FLAT": {"metric_type": "COSINE", "params": {"nprobe": 10}},
74+
"IVF_SQ8": {"metric_type": "COSINE", "params": {"nprobe": 10}},
75+
"IVF_PQ": {"metric_type": "COSINE", "params": {"nprobe": 10}},
76+
"HNSW": {"metric_type": "COSINE", "params": {"ef": 10}},
77+
"RHNSW_FLAT": {"metric_type": "COSINE", "params": {"ef": 10}},
78+
"RHNSW_SQ": {"metric_type": "COSINE", "params": {"ef": 10}},
79+
"RHNSW_PQ": {"metric_type": "COSINE", "params": {"ef": 10}},
80+
"IVF_HNSW": {"metric_type": "COSINE", "params": {"nprobe": 10, "ef": 10}},
81+
"ANNOY": {"metric_type": "COSINE", "params": {"search_k": 10}},
82+
"AUTOINDEX": {"metric_type": "COSINE", "params": {}},
83+
} if manager.MPNet_base else None
84+
)
85+
)
5486

5587

5688
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),

model/download_bert_embedder.bat

Lines changed: 0 additions & 1 deletion
This file was deleted.

modelcache/adapter/adapter_query.py

Lines changed: 80 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# -*- coding: utf-8 -*-
22
import logging
33
import time
4+
45
from modelcache import cache
56
from modelcache.utils.error import NotInitError
67
from modelcache.utils.time import time_cal
78
from modelcache.processor.pre import multi_analysis
89
from FlagEmbedding import FlagReranker
10+
from modelcache.manager.vector_data import manager
911

1012
USE_RERANKER = False # If True, enable the reranker; otherwise use the original logic
1113

@@ -44,39 +46,47 @@ def adapt_query(cache_data_convert, *args, **kwargs):
4446
cache_answers = []
4547
cache_questions = []
4648
cache_ids = []
47-
similarity_threshold = chat_cache.config.similarity_threshold
48-
similarity_threshold_long = chat_cache.config.similarity_threshold_long
49+
cosine_similarity = cache_data_list[0][0]
4950

50-
min_rank, max_rank = chat_cache.similarity_evaluation.range()
51-
rank_threshold = (max_rank - min_rank) * similarity_threshold * cache_factor
52-
rank_threshold_long = (max_rank - min_rank) * similarity_threshold_long * cache_factor
53-
rank_threshold = (
54-
max_rank
55-
if rank_threshold > max_rank
56-
else min_rank
57-
if rank_threshold < min_rank
58-
else rank_threshold
59-
)
60-
rank_threshold_long = (
61-
max_rank
62-
if rank_threshold_long > max_rank
63-
else min_rank
64-
if rank_threshold_long < min_rank
65-
else rank_threshold_long
66-
)
67-
if cache_data_list is None or len(cache_data_list) == 0:
68-
rank_pre = -1.0
51+
if manager.MPNet_base:
52+
# This code uses the built-in cosine similarity evaluation in milvus
53+
if cosine_similarity < 0.9:
54+
return None
6955
else:
70-
cache_data_dict = {'search_result': cache_data_list[0]}
71-
rank_pre = chat_cache.similarity_evaluation.evaluation(
72-
None,
73-
cache_data_dict,
74-
extra_param=context.get("evaluation_func", None),
56+
## this is the code that uses L2 for similarity evaluation
57+
similarity_threshold = chat_cache.config.similarity_threshold
58+
similarity_threshold_long = chat_cache.config.similarity_threshold_long
59+
60+
min_rank, max_rank = chat_cache.similarity_evaluation.range()
61+
rank_threshold = (max_rank - min_rank) * similarity_threshold * cache_factor
62+
rank_threshold_long = (max_rank - min_rank) * similarity_threshold_long * cache_factor
63+
rank_threshold = (
64+
max_rank
65+
if rank_threshold > max_rank
66+
else min_rank
67+
if rank_threshold < min_rank
68+
else rank_threshold
69+
)
70+
rank_threshold_long = (
71+
max_rank
72+
if rank_threshold_long > max_rank
73+
else min_rank
74+
if rank_threshold_long < min_rank
75+
else rank_threshold_long
7576
)
76-
if rank_pre < rank_threshold:
77-
return None
77+
if cache_data_list is None or len(cache_data_list) == 0:
78+
rank_pre = -1.0
79+
else:
80+
cache_data_dict = {'search_result': cache_data_list[0]}
81+
rank_pre = chat_cache.similarity_evaluation.evaluation(
82+
None,
83+
cache_data_dict,
84+
extra_param=context.get("evaluation_func", None),
85+
)
86+
if rank_pre < rank_threshold:
87+
return None
7888

79-
if USE_RERANKER:
89+
if USE_RERANKER and not manager.MPNet_base:
8090
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)
8191
for cache_data in cache_data_list:
8292
primary_id = cache_data[1]
@@ -132,45 +142,50 @@ def adapt_query(cache_data_convert, *args, **kwargs):
132142
if ret is None:
133143
continue
134144

135-
if "deps" in context and hasattr(ret.question, "deps"):
136-
eval_query_data = {
137-
"question": context["deps"][0]["data"],
138-
"embedding": None
139-
}
140-
eval_cache_data = {
141-
"question": ret.question.deps[0].data,
142-
"answer": ret.answers[0].answer,
143-
"search_result": cache_data,
144-
"embedding": None,
145-
}
145+
if manager.MPNet_base:
146+
cache_answers.append((cosine_similarity, ret[1]))
147+
cache_questions.append((cosine_similarity, ret[0]))
148+
cache_ids.append((cosine_similarity, primary_id))
146149
else:
147-
eval_query_data = {
148-
"question": pre_embedding_data,
149-
"embedding": embedding_data,
150-
}
150+
if "deps" in context and hasattr(ret.question, "deps"):
151+
eval_query_data = {
152+
"question": context["deps"][0]["data"],
153+
"embedding": None
154+
}
155+
eval_cache_data = {
156+
"question": ret.question.deps[0].data,
157+
"answer": ret.answers[0].answer,
158+
"search_result": cache_data,
159+
"embedding": None,
160+
}
161+
else:
162+
eval_query_data = {
163+
"question": pre_embedding_data,
164+
"embedding": embedding_data,
165+
}
151166

152-
eval_cache_data = {
153-
"question": ret[0],
154-
"answer": ret[1],
155-
"search_result": cache_data,
156-
"embedding": None
157-
}
158-
rank = chat_cache.similarity_evaluation.evaluation(
159-
eval_query_data,
160-
eval_cache_data,
161-
extra_param=context.get("evaluation_func", None),
162-
)
167+
eval_cache_data = {
168+
"question": ret[0],
169+
"answer": ret[1],
170+
"search_result": cache_data,
171+
"embedding": None
172+
}
173+
rank = chat_cache.similarity_evaluation.evaluation(
174+
eval_query_data,
175+
eval_cache_data,
176+
extra_param=context.get("evaluation_func", None),
177+
)
163178

164-
if len(pre_embedding_data) <= 256:
165-
if rank_threshold <= rank:
166-
cache_answers.append((rank, ret[1]))
167-
cache_questions.append((rank, ret[0]))
168-
cache_ids.append((rank, primary_id))
169-
else:
170-
if rank_threshold_long <= rank:
171-
cache_answers.append((rank, ret[1]))
172-
cache_questions.append((rank, ret[0]))
173-
cache_ids.append((rank, primary_id))
179+
if len(pre_embedding_data) <= 256:
180+
if rank_threshold <= rank:
181+
cache_answers.append((rank, ret[1]))
182+
cache_questions.append((rank, ret[0]))
183+
cache_ids.append((rank, primary_id))
184+
else:
185+
if rank_threshold_long <= rank:
186+
cache_answers.append((rank, ret[1]))
187+
cache_questions.append((rank, ret[0]))
188+
cache_ids.append((rank, primary_id))
174189

175190
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
176191
cache_questions = sorted(cache_questions, key=lambda x: x[0], reverse=True)

modelcache/embedding/mpnet_base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from sentence_transformers import SentenceTransformer
2+
3+
class MPNet_Base:
    """Sentence-embedding wrapper around ``sentence-transformers/all-mpnet-base-v2``.

    Attributes:
        dimension: Embedding vector size, fixed at 768.
        model: The ``SentenceTransformer`` instance loaded in ``__init__``.
    """

    def __init__(self):
        self.dimension = 768
        self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    @staticmethod
    def _is_empty(value):
        """Return True when *value* is None or a zero-length container."""
        if value is None:
            return True
        try:
            return len(value) == 0
        except TypeError:
            # Unsized objects (e.g. scalars) have no length and are not "empty".
            return False

    def embedding_func(self, *args, **kwargs):
        """Encode one or more inputs with the underlying model.

        Args:
            *args: The items to embed; at least one is required.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The first element of the encoder output when exactly one
            argument is given, otherwise the whole encoder output.

        Raises:
            ValueError: If no positional argument is supplied.
        """
        if not args:
            raise ValueError("No word provided for embedding.")
        # Hand the encoder a list rather than the raw *args tuple.
        embeddings = self.model.encode(list(args))
        return embeddings[0] if len(args) == 1 else embeddings

    def similarity(self, a, b):
        """Return ``self.model.similarity(a, b)`` for two non-empty inputs.

        Emptiness is tested with an explicit None / zero-length check rather
        than truthiness, because the truth value of a multi-element array is
        ambiguous and evaluating it raises before the model is ever called.

        Raises:
            ValueError: If either input is None or has zero length.
        """
        if self._is_empty(a) or self._is_empty(b):
            raise ValueError("Both inputs must be non-empty for similarity calculation.")
        return self.model.similarity(a, b)

modelcache/manager/vector_data/manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
COLLECTION_NAME = "modelcache"
1919

20+
MPNet_base = False # Whether to use the MPNet base model for embeddings; if True, Milvus's built-in cosine similarity evaluation is used
21+
2022

2123
class VectorBase:
2224
"""

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,4 @@ elasticsearch==7.10.0
1919
snowflake-id==1.0.2
2020
flagembedding==1.3.4
2121
cryptography==45.0.2
22-
mediapipe==0.10.21
23-
protobuf==4.25.8
22+
sentence-transformers==4.1.0

0 commit comments

Comments
 (0)