Skip to content

Commit b5c561f

Browse files
Merge pull request #831 from Anush008/fastembed-support
feat: FastEmbedVectorizer
2 parents d7dace6 + 7738d8b commit b5c561f

File tree

5 files changed

+96
-2
lines changed

5 files changed

+96
-2
lines changed

dsp/modules/sentence_vectorizer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,48 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
203203

204204
embeddings = np.array(embeddings_list, dtype=np.float32)
205205
return embeddings
206+
207+
class FastEmbedVectorizer(BaseSentenceVectorizer):
    """Sentence vectorizer implementation using FastEmbed - https://qdrant.github.io/fastembed."""

    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
        batch_size: int = 256,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        parallel: Optional[int] = None,
        **kwargs,
    ):
        """Initialize fastembed.TextEmbedding.

        Args:
            model_name (str): The name of the model to use. Defaults to `"BAAI/bge-small-en-v1.5"`.
            batch_size (int): Batch size for encoding. Higher values will use more memory, but be faster.\
                Defaults to 256.
            cache_dir (str, optional): The path to the model cache directory.\
                Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
            threads (int, optional): The number of threads single onnxruntime session can use.
            parallel (int, optional): If `>1`, data-parallel encoding will be used, recommended for large datasets.\
                If `0`, use all available cores.\
                If `None`, don't use data-parallel processing, use default onnxruntime threading.\
                Defaults to None.
            **kwargs: Additional options to pass to fastembed.TextEmbedding.

        Raises:
            ValueError: If the 'fastembed' package is not installed, or if the model_name is not
                in the format <org>/<model> e.g. BAAI/bge-small-en-v1.5.
        """
        # Import lazily so the module remains importable when the optional
        # 'fastembed' extra is not installed.
        try:
            from fastembed import TextEmbedding
        except ImportError as e:
            # Kept as ValueError (not re-raising ImportError) for backward
            # compatibility with existing callers that catch ValueError.
            raise ValueError(
                "The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
            ) from e
        self._batch_size = batch_size
        self._parallel = parallel
        self._model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs)

    def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
        """Embed the text of each input into a single float32 matrix.

        Args:
            inp_examples: Examples (or plain strings, as handled by the base
                class's `_extract_text_from_examples`) to vectorize.

        Returns:
            np.ndarray: Array of shape (len(inp_examples), embedding_dim) with dtype float32.
        """
        texts_to_vectorize = self._extract_text_from_examples(inp_examples)
        # TextEmbedding.embed yields one vector per input text; materialize the
        # stream into a dense 2-D array.
        embeddings = self._model.embed(texts_to_vectorize, batch_size=self._batch_size, parallel=self._parallel)

        return np.array([embedding.tolist() for embedding in embeddings], dtype=np.float32)

poetry.lock

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ docs = [
6363
"sphinx-automodapi==0.16.0",
6464
]
6565
dev = ["pytest>=6.2.5"]
66+
fastembed = ["fastembed>=0.2.0"]
6667

6768
[project.urls]
6869
homepage = "https://github.com/stanfordnlp/dspy"
@@ -96,7 +97,7 @@ requests = "^2.31.0"
9697
optuna = "^3.4.0"
9798
anthropic = { version = "^0.18.0", optional = true }
9899
chromadb = { version = "^0.4.14", optional = true }
99-
fastembed = { version = "^0.2.0", optional = true }
100+
fastembed = { version = ">=0.2.0", optional = true }
100101
marqo = { version = "*", optional = true }
101102
qdrant-client = { version = "^1.6.2", optional = true }
102103
pinecone-client = { version = "^2.2.4", optional = true }
@@ -153,6 +154,7 @@ docs = [
153154
"sphinx-reredirects",
154155
"sphinx-automodapi",
155156
]
157+
fastembed = ["fastembed"]
156158

157159
[tool.poetry.group.doc.dependencies]
158160
mkdocs = ">=1.5.3"

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"faiss-cpu": ["sentence_transformers", "faiss-cpu"],
3232
"milvus": ["pymilvus~=2.3.7"],
3333
"google-vertex-ai": ["google-cloud-aiplatform==1.43.0"],
34+
"fastembed": ["fastembed"],
3435
},
3536
classifiers=[
3637
"Development Status :: 3 - Alpha",
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from dsp.modules.sentence_vectorizer import FastEmbedVectorizer
2+
import pytest
3+
4+
from dspy.primitives.example import Example
5+
6+
# Skip the test if the 'fastembed' package is not installed
7+
pytest.importorskip("fastembed", reason="'fastembed' is not installed. Use `pip install fastembed` to install it.")
8+
9+
10+
@pytest.mark.parametrize(
    "n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5"), (512, "jinaai/jina-embeddings-v2-small-en")]
)
def test_fastembed_with_examples(n_dims, model_name):
    """Vectorizing dspy Examples yields a matrix of shape (num_examples, n_dims)."""
    qa_pairs = [
        ("What's the price today?", "The price is $10.00"),
        ("What's the weather today?", "The weather is sunny"),
        ("Who was leading the team?", "It was Jim. Rather enthusiastic guy."),
    ]
    dataset = [
        Example(query=query, response=response).with_inputs("query", "response")
        for query, response in qa_pairs
    ]

    vectorized = FastEmbedVectorizer(model_name)(dataset)

    assert vectorized.shape == (len(dataset), n_dims)
27+
28+
29+
@pytest.mark.parametrize(
    "n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5"), (512, "jinaai/jina-embeddings-v2-small-en")]
)
def test_fastembed_with_strings(n_dims, model_name):
    """Vectorizing plain strings yields a matrix of shape (num_inputs, n_dims)."""
    sentences = [
        "Jonathan Kent is a fictional character appearing in American comic books published by DC Comics.",
        "Clark Kent is a fictional character appearing in American comic books published by DC Comics.",
        "Martha Kent is a fictional character appearing in American comic books published by DC Comics.",
    ]

    matrix = FastEmbedVectorizer(model_name)(sentences)

    assert matrix.shape == (len(sentences), n_dims)

0 commit comments

Comments
 (0)