11import asyncio
22import json
3+ import sqlite3
34import uuid
45from pathlib import Path
56from typing import Dict , List , Optional , Type
67
8+ import numpy as np
9+ import sqlite_vec_sl_tmp
710import structlog
811from alembic import command as alembic_command
912from alembic .config import Config as AlembicConfig
2225 IntermediatePromptWithOutputUsageAlerts ,
2326 MuxRule ,
2427 Output ,
28+ Persona ,
29+ PersonaDistance ,
30+ PersonaEmbedding ,
2531 Prompt ,
2632 ProviderAuthMaterial ,
2733 ProviderEndpoint ,
@@ -65,7 +71,7 @@ def __new__(cls, *args, **kwargs):
6571 # It should only be used for testing
6672 if "_no_singleton" in kwargs and kwargs ["_no_singleton" ]:
6773 kwargs .pop ("_no_singleton" )
68- return super ().__new__ (cls , * args , ** kwargs )
74+ return super ().__new__ (cls )
6975
7076 if cls ._instance is None :
7177 cls ._instance = super ().__new__ (cls )
@@ -92,6 +98,22 @@ def __init__(self, sqlite_path: Optional[str] = None, **kwargs):
9298 }
9399 self ._async_db_engine = create_async_engine (** engine_dict )
94100
def _get_vec_db_connection(self):
    """
    Open a separate sqlite3 connection with the vector extension loaded.

    Vector database connection is a separate connection to the SQLite database. aiosqlite
    does not support loading extensions, so we need to use the sqlite3 module to load the
    vector extension.

    Returns:
        An open sqlite3 connection to ``self._db_path``. The connection is returned
        open, so the caller is responsible for closing it.

    Raises:
        Exception: Any error raised while connecting or loading the extension is
            logged and re-raised unchanged.
    """
    conn = None
    try:
        conn = sqlite3.connect(self._db_path)
        # Extension loading is enabled only for the duration of the load call.
        conn.enable_load_extension(True)
        sqlite_vec_sl_tmp.load(conn)
        conn.enable_load_extension(False)
        return conn
    except Exception:
        logger.exception("Failed to initialize vector database connection")
        # Do not leak a half-initialized connection when setup fails after connect().
        if conn is not None:
            conn.close()
        raise
116+
def does_db_exist(self):
    """Report whether ``self._db_path`` is an existing file, via its ``is_file()`` method."""
    db_file = self._db_path
    return db_file.is_file()
97119
@@ -523,6 +545,30 @@ async def add_mux(self, mux: MuxRule) -> MuxRule:
523545 added_mux = await self ._execute_update_pydantic_model (mux , sql , should_raise = True )
524546 return added_mux
525547
async def add_persona(self, persona: PersonaEmbedding) -> None:
    """Add a new Persona to the DB.

    Inserts the persona's id, name, description and description embedding into the
    ``personas`` table.

    Args:
        persona: The persona to insert, including its ``description_embedding``.

    Raises:
        AlreadyExistsError: If the insert fails with an ``IntegrityError``. The
            original error is attached as the cause.
    """
    sql = text(
        """
        INSERT INTO personas (id, name, description, description_embedding)
        VALUES (:id, :name, :description, :description_embedding)
        """
    )

    try:
        # For Pydantic we convert the numpy array to string when serializing with
        # .model_dump(). Put the original embedding object back into the parameters
        # before inserting it into the DB.
        persona_dict = persona.model_dump()
        persona_dict["description_embedding"] = persona.description_embedding
        await self._execute_with_no_return(sql, persona_dict)
    except IntegrityError as e:
        logger.debug(f"Exception type: {type(e)}")
        # Chain explicitly so the underlying constraint failure stays visible.
        raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") from e
571+
526572
527573class DbReader (DbCodeGate ):
528574 def __init__ (self , sqlite_path : Optional [str ] = None , * args , ** kwargs ):
@@ -569,6 +615,20 @@ async def _exec_select_conditions_to_pydantic(
569615 raise e
570616 return None
571617
async def _exec_vec_db_query_to_pydantic(
    self, sql_command: str, conditions: dict, model_type: Type[BaseModel]
) -> List[BaseModel]:
    """
    Execute a query on the vector database. This is a separate connection to the SQLite
    database that has the vector extension loaded.

    Args:
        sql_command: SQL statement to execute; may use named placeholders.
        conditions: Values bound to the placeholders in ``sql_command``.
        model_type: Class instantiated once per row, with the row's columns passed
            as keyword arguments.

    Returns:
        One ``model_type`` instance per result row, in the order rows are returned.

    Raises:
        Exception: Errors from opening the connection, running the query or building
            a model propagate to the caller. The connection is closed in every case.
    """
    # NOTE(review): the sqlite3 calls below are synchronous and are not awaited, so
    # they run inline in this coroutine - confirm that this is acceptable.
    conn = self._get_vec_db_connection()
    try:
        # sqlite3.Row exposes column names, so each row can be unpacked with **.
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        return [model_type(**row) for row in cursor.execute(sql_command, conditions)]
    finally:
        # Always release the connection, even if the query or model creation fails.
        conn.close()
631+
572632 async def get_prompts_with_output (self , workpace_id : str ) -> List [GetPromptWithOutputsRow ]:
573633 sql = text (
574634 """
@@ -893,6 +953,45 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
893953 )
894954 return muxes
895955
async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
    """
    Look up a persona by its name.

    Selects id, name and description from the ``personas`` table and returns the
    first matching row as a ``Persona``, or ``None`` when the lookup yields nothing.
    """
    query = text(
        """
        SELECT
            id, name, description
        FROM personas
        WHERE name = :name
        """
    )
    matches = await self._exec_select_conditions_to_pydantic(
        Persona, query, {"name": persona_name}, should_raise=True
    )
    if not matches:
        return None
    return matches[0]
973+
async def get_distance_to_persona(
    self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
    """
    Get the distance between a persona and a query embedding.

    The distance is computed in SQL with ``vec_distance_cosine`` between the stored
    ``description_embedding`` and the supplied query embedding.

    Args:
        persona_id: Value matched against the ``id`` column of ``personas``.
        query_embedding: Embedding bound to the ``:query_embedding`` placeholder.
            NOTE(review): presumably it must match the stored embedding's dimension
            and dtype - confirm against the code that writes the embeddings.

    Returns:
        A ``PersonaDistance`` built from the matching row (id, name, description,
        distance).

    Raises:
        IndexError: If no persona with ``persona_id`` exists.
    """
    sql = """
        SELECT
            id,
            name,
            description,
            vec_distance_cosine(description_embedding, :query_embedding) as distance
        FROM personas
        WHERE id = :id
    """
    conditions = {"id": persona_id, "query_embedding": query_embedding}
    persona_distance = await self._exec_vec_db_query_to_pydantic(
        sql, conditions, PersonaDistance
    )
    if not persona_distance:
        # Same exception type as indexing an empty list, but with a useful message.
        raise IndexError(f"No persona found with id '{persona_id}'")
    return persona_distance[0]
994+
896995
897996def init_db_sync (db_path : Optional [str ] = None ):
898997 """DB will be initialized in the constructor in case it doesn't exist."""
0 commit comments