Skip to content

Commit 4c00977

Browse files
Merge pull request #874 from XiaoConstantine/xiao/chroma-client
feat: Provide optional client for ChromadbRM
2 parents d09d984 + 76dd124 commit 4c00977

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

dspy/retrieve/chromadb_rm.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@
2626
from chromadb.config import Settings
2727
from chromadb.utils import embedding_functions
2828
except ImportError:
29-
chromadb = None
30-
31-
if chromadb is None:
3229
raise ImportError(
3330
"The chromadb library is required to use ChromadbRM. Install it with `pip install dspy-ai[chromadb]`",
3431
)
@@ -46,6 +43,7 @@ class ChromadbRM(dspy.Retrieve):
4643
persist_directory (str): chromadb persist directory
4744
embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function to use to embed documents. Defaults to DefaultEmbeddingFunction.
4845
k (int, optional): The number of top passages to retrieve. Defaults to 7.
46+
client(Optional[chromadb.Client]): Optional chromadb client provided by user, default to None
4947
5048
Returns:
5149
dspy.Prediction: An object containing the retrieved passages.
@@ -54,12 +52,25 @@ class ChromadbRM(dspy.Retrieve):
5452
Below is a code snippet that shows how to use this as the default retriever:
5553
```python
5654
llm = dspy.OpenAI(model="gpt-3.5-turbo")
55+
# using default chromadb client
5756
retriever_model = ChromadbRM('collection_name', 'db_path')
5857
dspy.settings.configure(lm=llm, rm=retriever_model)
5958
# to test the retriever with "my query"
6059
retriever_model("my query")
6160
```
6261
62+
Use provided chromadb client
63+
```python
64+
import chromadb
65+
llm = dspy.OpenAI(model="gpt-3.5-turbo")
66+
# say you have a chromadb running on a different port
67+
client = chromadb.HttpClient(host='localhost', port=8889)
68+
retriever_model = ChromadbRM('collection_name', 'db_path', client=client)
69+
dspy.settings.configure(lm=llm, rm=retriever_model)
70+
# to test the retriever with "my query"
71+
retriever_model("my query")
72+
```
73+
6374
Below is a code snippet that shows how to use this in the forward() function of a module
6475
```python
6576
self.retrieve = ChromadbRM('collection_name', 'db_path', k=num_passages)
@@ -73,9 +84,10 @@ def __init__(
7384
embedding_function: Optional[
7485
EmbeddingFunction[Embeddable]
7586
] = ef.DefaultEmbeddingFunction(),
87+
client: Optional[chromadb.Client] = None,
7688
k: int = 7,
7789
):
78-
self._init_chromadb(collection_name, persist_directory)
90+
self._init_chromadb(collection_name, persist_directory, client=client)
7991
self.ef = embedding_function
8092

8193
super().__init__(k=k)
@@ -84,22 +96,26 @@ def _init_chromadb(
8496
self,
8597
collection_name: str,
8698
persist_directory: str,
99+
client: Optional[chromadb.Client] = None,
87100
) -> chromadb.Collection:
88101
"""Initialize chromadb and return the loaded index.
89102
90103
Args:
91104
collection_name (str): chromadb collection name
92105
persist_directory (str): chromadb persist directory
106+
client (chromadb.Client): chromadb client provided by user
93107
94-
95-
Returns:
108+
Returns: collection per collection_name
96109
"""
97110

98-
self._chromadb_client = chromadb.Client(
99-
Settings(
100-
persist_directory=persist_directory,
101-
is_persistent=True,
102-
),
111+
if client:
112+
self._chromadb_client = client
113+
else:
114+
self._chromadb_client = chromadb.Client(
115+
Settings(
116+
persist_directory=persist_directory,
117+
is_persistent=True,
118+
),
103119
)
104120
self._chromadb_collection = self._chromadb_client.get_or_create_collection(
105121
name=collection_name,

0 commit comments

Comments
 (0)