2626 from chromadb .config import Settings
2727 from chromadb .utils import embedding_functions
2828except ImportError :
29- chromadb = None
30-
31- if chromadb is None :
3229 raise ImportError (
3330 "The chromadb library is required to use ChromadbRM. Install it with `pip install dspy-ai[chromadb]`" ,
3431 )
@@ -46,6 +43,7 @@ class ChromadbRM(dspy.Retrieve):
4643 persist_directory (str): chromadb persist directory
4744 embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function to use to embed documents. Defaults to DefaultEmbeddingFunction.
4845 k (int, optional): The number of top passages to retrieve. Defaults to 7.
46+ client (Optional[chromadb.Client]): Optional chromadb client provided by the user. Defaults to None.
4947
5048 Returns:
5149 dspy.Prediction: An object containing the retrieved passages.
@@ -54,12 +52,25 @@ class ChromadbRM(dspy.Retrieve):
5452 Below is a code snippet that shows how to use this as the default retriever:
5553 ```python
5654 llm = dspy.OpenAI(model="gpt-3.5-turbo")
55+ # using default chromadb client
5756 retriever_model = ChromadbRM('collection_name', 'db_path')
5857 dspy.settings.configure(lm=llm, rm=retriever_model)
5958 # to test the retriever with "my query"
6059 retriever_model("my query")
6160 ```
6261
62+ Below is a code snippet that shows how to use a chromadb client provided by the user:
63+ ```python
64+ import chromadb
65+ llm = dspy.OpenAI(model="gpt-3.5-turbo")
66+ # say you have a chromadb running on a different port
67+ client = chromadb.HttpClient(host='localhost', port=8889)
68+ retriever_model = ChromadbRM('collection_name', 'db_path', client=client)
69+ dspy.settings.configure(lm=llm, rm=retriever_model)
70+ # to test the retriever with "my query"
71+ retriever_model("my query")
72+ ```
73+
6374 Below is a code snippet that shows how to use this in the forward() function of a module
6475 ```python
6576 self.retrieve = ChromadbRM('collection_name', 'db_path', k=num_passages)
@@ -73,9 +84,10 @@ def __init__(
7384 embedding_function : Optional [
7485 EmbeddingFunction [Embeddable ]
7586 ] = ef .DefaultEmbeddingFunction (),
87+ client : Optional [chromadb .Client ] = None ,
7688 k : int = 7 ,
7789 ):
78- self ._init_chromadb (collection_name , persist_directory )
90+ self ._init_chromadb (collection_name , persist_directory , client = client )
7991 self .ef = embedding_function
8092
8193 super ().__init__ (k = k )
@@ -84,22 +96,26 @@ def _init_chromadb(
8496 self ,
8597 collection_name : str ,
8698 persist_directory : str ,
99+ client : Optional [chromadb .Client ] = None ,
87100 ) -> chromadb .Collection :
88101 """Initialize chromadb and return the loaded index.
89102
90103 Args:
91104 collection_name (str): chromadb collection name
92105 persist_directory (str): chromadb persist directory
106+ client (Optional[chromadb.Client]): chromadb client provided by the user. If not given, a persistent client is created from persist_directory.
93107
94-
95- Returns:
108+ Returns: the chromadb collection identified by collection_name
96109 """
97110
98- self ._chromadb_client = chromadb .Client (
99- Settings (
100- persist_directory = persist_directory ,
101- is_persistent = True ,
102- ),
111+ if client :
112+ self ._chromadb_client = client
113+ else :
114+ self ._chromadb_client = chromadb .Client (
115+ Settings (
116+ persist_directory = persist_directory ,
117+ is_persistent = True ,
118+ ),
103119 )
104120 self ._chromadb_collection = self ._chromadb_client .get_or_create_collection (
105121 name = collection_name ,
0 commit comments