@@ -43,6 +43,7 @@ class ChromadbRM(dspy.Retrieve):
4343 persist_directory (str): chromadb persist directory
4444 embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function to use to embed documents. Defaults to DefaultEmbeddingFunction.
4545 k (int, optional): The number of top passages to retrieve. Defaults to 7.
46+ client(Optional[chromadb.Client]): Optional chromadb client provided by user, default to None
4647
4748 Returns:
4849 dspy.Prediction: An object containing the retrieved passages.
@@ -51,12 +52,25 @@ class ChromadbRM(dspy.Retrieve):
5152 Below is a code snippet that shows how to use this as the default retriever:
5253 ```python
5354 llm = dspy.OpenAI(model="gpt-3.5-turbo")
55+ # using default chromadb client
5456 retriever_model = ChromadbRM('collection_name', 'db_path')
5557 dspy.settings.configure(lm=llm, rm=retriever_model)
5658 # to test the retriever with "my query"
5759 retriever_model("my query")
5860 ```
5961
62+ Use provided chromadb client
63+ ```python
64+ import chromadb
65+ llm = dspy.OpenAI(model="gpt-3.5-turbo")
66+ # say you have a chromadb running on a different port
67+ client = chromadb.HttpClient(host='localhost', port=8889)
68+ retriever_model = ChromadbRM('collection_name', 'db_path', client=client)
69+ dspy.settings.configure(lm=llm, rm=retriever_model)
70+ # to test the retriever with "my query"
71+ retriever_model("my query")
72+ ```
73+
6074 Below is a code snippet that shows how to use this in the forward() function of a module
6175 ```python
6276 self.retrieve = ChromadbRM('collection_name', 'db_path', k=num_passages)
@@ -89,7 +103,7 @@ def _init_chromadb(
89103 Args:
90104 collection_name (str): chromadb collection name
91105 persist_directory (str): chromadb persist directory
92- client (chromadb.Client): A chromadb client provided by user
106+ client (chromadb.Client): chromadb client provided by user
93107
94108 Returns: collection per collection_name
95109 """
0 commit comments