@@ -43,6 +43,7 @@ def __init__(self):
4343 self ._n_gpu = conf .chat_model_n_gpu_layers
4444 self ._persona_threshold = conf .persona_threshold
4545 self ._persona_diff_desc_threshold = conf .persona_diff_desc_threshold
46+ self ._distances_weight_factor = conf .distances_weight_factor
4647 self ._db_recorder = DbRecorder ()
4748 self ._db_reader = DbReader ()
4849
@@ -99,18 +100,17 @@ def _clean_text_for_embedding(self, text: str) -> str:
99100
100101 return text
101102
async def _embed_texts(self, texts: List[str]) -> np.ndarray:
    """
    Embed a batch of texts using the inference engine.

    Every text is first normalised with ``_clean_text_for_embedding`` and the
    whole batch is then sent to the embeddings model in a single call.

    Args:
        texts: The raw texts to embed.

    Returns:
        A float32 NumPy array built from the list of embeddings returned by
        the inference engine.
    """
    prepared = list(map(self._clean_text_for_embedding, texts))
    # .embed returns a list of embeddings
    raw_embeddings = await self._inference_engine.embed(
        self._embeddings_model, prepared, n_gpu_layers=self._n_gpu
    )
    logger.debug("Text embedded in semantic routing", num_texts=len(texts))
    return np.array(raw_embeddings, dtype=np.float32)
114114
115115 async def _is_persona_description_diff (
116116 self , emb_persona_desc : np .ndarray , exclude_id : Optional [str ]
@@ -142,7 +142,8 @@ async def _validate_persona_description(
142142 Validate the persona description by embedding the text and checking if it is
143143 different enough from existing personas.
144144 """
145- emb_persona_desc = await self ._embed_text (persona_desc )
145+ emb_persona_desc_list = await self ._embed_texts ([persona_desc ])
146+ emb_persona_desc = emb_persona_desc_list [0 ]
146147 if not await self ._is_persona_description_diff (emb_persona_desc , exclude_id ):
147148 raise PersonaSimilarDescriptionError (
148149 "The persona description is too similar to existing personas."
@@ -217,21 +218,87 @@ async def delete_persona(self, persona_name: str) -> None:
217218 await self ._db_recorder .delete_persona (persona .id )
218219 logger .info (f"Deleted persona { persona_name } from the database." )
219220
220- async def check_persona_match (self , persona_name : str , query : str ) -> bool :
async def _get_cosine_distance(
    self, emb_queries: np.ndarray, emb_persona: np.ndarray
) -> np.ndarray:
    """
    Calculate the cosine distance between the query embeddings and the persona embedding.

    Definition of cosine distance: 1 - cosine similarity
    [Cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)

    NOTE: Experimented by individually querying SQLite for each query, but as the number
    of queries increases, the performance is better with NumPy. If the number of queries
    is small the performance is on par. Hence the decision to use NumPy.

    Args:
        emb_queries: Matrix of shape (N, M), where N is the number of queries (user
            messages in this case) and M is the number of dimensions in the embedding.
            A single vector of length M is accepted and treated as one query.
        emb_persona: Single vector of length M.

    Returns:
        A NumPy array of length N with one cosine distance per query. An empty array
        is returned when there are no queries. A query or persona embedding with zero
        norm yields NaN for the affected entries, because the similarity is undefined.
    """
    # No queries at all: without this guard an empty 1-D array would be reshaped
    # to (1, 0) below and the dot product would fail with a shape mismatch.
    if emb_queries.size == 0:
        return np.array([], dtype=np.float32)

    # Handle the case where we have a single query (single user message)
    if emb_queries.ndim == 1:
        emb_queries = emb_queries.reshape(1, -1)

    emb_queries_norm = np.linalg.norm(emb_queries, axis=1)
    persona_embed_norm = np.linalg.norm(emb_persona)
    cosine_similarities = np.dot(emb_queries, emb_persona.T) / (
        emb_queries_norm * persona_embed_norm
    )
    # We could also use directly cosine_similarities but we get the distance to match
    # the behavior of SQLite function vec_distance_cosine
    cosine_distances = 1 - cosine_similarities
    return cosine_distances
249+
async def _weight_distances(self, distances: np.ndarray) -> np.ndarray:
    """
    Scale the distances so that later positions matter more.

    The distances correspond to user messages, the last one being the most recent
    and therefore the most important. Each distance is divided by
    ``self._distances_weight_factor`` raised to its offset from the end of the
    array, so the last distance is left unchanged. With a factor between 0 and 1,
    lower values create a steeper importance curve and a factor of 1 leaves every
    distance unchanged.

    Args:
        distances: NumPy array of float values between 0 and 2.

    Returns:
        Weighted distances as a NumPy array.
    """
    count = len(distances)

    # Offsets from the end: count-1, ..., 1, 0. The last element gets exponent 0,
    # hence a weight of exactly 1.
    exponents = np.arange(count)[::-1]
    weights = np.power(self._distances_weight_factor, exponents)

    # Dividing by a smaller weight produces a larger effective distance.
    return distances / weights
280+
async def check_persona_match(self, persona_name: str, queries: List[str]) -> bool:
    """
    Check if any of the queries matches the persona description.

    The queries are embedded and compared with the persona description embedding
    using the cosine distance implemented in _get_cosine_distance: 0 means the
    vectors point in the same direction, 1 means they are orthogonal and 2 means
    they point in opposite directions. The distances are then weighted with
    _weight_distances before being compared with the persona threshold.

    Args:
        persona_name: Name of the persona to match against.
        queries: Texts to compare with the persona description (the user messages).

    Returns:
        True if at least one weighted distance is below the persona threshold,
        False otherwise, including when there are no queries.

    Raises:
        PersonaDoesNotExistError: If no persona embedding is found for persona_name.
    """
    persona_embed = await self._db_reader.get_persona_embed_by_name(persona_name)
    if not persona_embed:
        raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

    # Nothing to compare: an empty batch cannot match, and it would otherwise
    # fail further down when the distances are computed.
    if not queries:
        return False

    emb_queries = await self._embed_texts(queries)
    cosine_distances = await self._get_cosine_distance(
        emb_queries, persona_embed.description_embedding
    )
    logger.debug("Cosine distances calculated", cosine_distances=cosine_distances)

    weighted_distances = await self._weight_distances(cosine_distances)
    logger.info("Weighted distances to persona", weighted_distances=weighted_distances)

    return bool(np.any(weighted_distances < self._persona_threshold))
0 commit comments