2626 from chromadb .config import Settings
2727 from chromadb .utils import embedding_functions
2828except ImportError :
29- chromadb = None
30-
31- if chromadb is None :
3229 raise ImportError (
3330 "The chromadb library is required to use ChromadbRM. Install it with `pip install dspy-ai[chromadb]`" ,
3431 )
@@ -73,9 +70,10 @@ def __init__(
7370 embedding_function : Optional [
7471 EmbeddingFunction [Embeddable ]
7572 ] = ef .DefaultEmbeddingFunction (),
73+ client : Optional [chromadb .Client ] = None ,
7674 k : int = 7 ,
7775 ):
78- self ._init_chromadb (collection_name , persist_directory )
76+ self ._init_chromadb (collection_name , persist_directory , client = client )
7977 self .ef = embedding_function
8078
8179 super ().__init__ (k = k )
@@ -84,22 +82,26 @@ def _init_chromadb(
8482 self ,
8583 collection_name : str ,
8684 persist_directory : str ,
85+ client : Optional [chromadb .Client ] = None
8786 ) -> chromadb .Collection :
8887 """Initialize chromadb and return the loaded index.
8988
9089 Args:
9190 collection_name (str): chromadb collection name
9291 persist_directory (str): chromadb persist directory
92+ client (chromadb.Client): A chromadb client provided by user
9393
94-
95- Returns:
94+ Returns: collection per collection_name
9695 """
9796
98- self ._chromadb_client = chromadb .Client (
99- Settings (
100- persist_directory = persist_directory ,
101- is_persistent = True ,
102- ),
97+ if client :
98+ self ._chromadb_client = client
99+ else :
100+ self ._chromadb_client = chromadb .Client (
101+ Settings (
102+ persist_directory = persist_directory ,
103+ is_persistent = True ,
104+ ),
103105 )
104106 self ._chromadb_collection = self ._chromadb_client .get_or_create_collection (
105107 name = collection_name ,
0 commit comments