Commit 20f530e

Add Sentence Transformers Embeddings (langchain-ai#3409)
Add embeddings based on the sentence-transformers library. Add a notebook and integration tests.

Co-authored-by: khimaros <me@khimaros.com>

1 parent 73bc70b · commit 20f530e

5 files changed: 226 additions, 1 deletion

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ed47bb62",
   "metadata": {},
   "source": [
    "# Sentence Transformers Embeddings\n",
    "\n",
    "Let's generate embeddings using the [SentenceTransformers](https://www.sbert.net/) integration. SentenceTransformers is a Python package that can generate text and image embeddings, originating from [Sentence-BERT](https://arxiv.org/abs/1908.10084)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "06c9f47d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "!pip install sentence_transformers > /dev/null"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "861521a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings import SentenceTransformerEmbeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ff9be586",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = SentenceTransformerEmbeddings(model=\"all-MiniLM-L6-v2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d0a98ae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"This is a test document.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5d6c682b",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_result = embeddings.embed_query(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bb5e74c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc_result = embeddings.embed_documents([text, \"This is not a test document.\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaad49f8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  },
  "vscode": {
   "interpreter": {
    "hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
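
The notebook stops at producing query_result and doc_result. As a quick illustration of what those vectors are good for, here is a sketch that ranks the two documents against the query by cosine similarity; the cosine_similarity helper below is illustrative only and not part of this commit:

import math
from typing import List

from langchain.embeddings import SentenceTransformerEmbeddings


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Plain-Python cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


embeddings = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2")
query_result = embeddings.embed_query("This is a test document.")
doc_result = embeddings.embed_documents(
    ["This is a test document.", "This is not a test document."]
)
# The identical sentence should score close to 1.0; the negated one lower.
print([round(cosine_similarity(query_result, d), 3) for d in doc_result])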

langchain/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
     SelfHostedHuggingFaceEmbeddings,
     SelfHostedHuggingFaceInstructEmbeddings,
 )
+from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
 
 logger = logging.getLogger(__name__)
@@ -42,6 +43,7 @@
     "FakeEmbeddings",
     "AlephAlphaAsymmetricSemanticEmbedding",
     "AlephAlphaSymmetricSemanticEmbedding",
+    "SentenceTransformerEmbeddings",
 ]
 

langchain/embeddings/sentence_transformer.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
"""Wrapper around sentence transformer embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.embeddings.base import Embeddings


class SentenceTransformerEmbeddings(BaseModel, Embeddings):
    embedding_function: Any  #: :meta private:

    model: Optional[str] = Field("all-MiniLM-L6-v2", alias="model")
    """Transformer model to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the sentence_transformers library is installed."""
        model = values["model"]

        try:
            from sentence_transformers import SentenceTransformer

            values["embedding_function"] = SentenceTransformer(model)
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import sentence_transformers library. "
                "Please install the sentence_transformers library to "
                "use this embedding model: pip install sentence_transformers"
            )
        except Exception:
            raise NameError(f"Could not load SentenceTransformer model {model}.")

        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using the SentenceTransformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = self.embedding_function.encode(
            texts, convert_to_numpy=True
        ).tolist()
        return [list(map(float, e)) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using the SentenceTransformer model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed_documents([text])[0]
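
For context, the class above is a thin shim over SentenceTransformer.encode. The following sketch shows the equivalent direct library call, with the same arguments the wrapper passes, purely to make the delegation explicit:

from sentence_transformers import SentenceTransformer

# encode() returns a numpy array of shape (n_texts, 384) for the default
# all-MiniLM-L6-v2 model; the wrapper converts it to plain lists of floats.
model = SentenceTransformer("all-MiniLM-L6-v2")
vectors = model.encode(["This is a test document."], convert_to_numpy=True).tolist()
print(len(vectors), len(vectors[0]))  # 1 text, 384 dimensions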

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -117,6 +117,7 @@ torch = "^1.0.0"
 chromadb = "^0.3.21"
 tiktoken = "^0.3.3"
 python-dotenv = "^1.0.0"
+sentence-transformers = "^2"
 gptcache = "^0.1.9"
 promptlayer = "^0.1.80"
 
@@ -144,7 +145,8 @@ llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifes
 qdrant = ["qdrant-client"]
 openai = ["openai"]
 cohere = ["cohere"]
-all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect"]
+embeddings = ["sentence-transformers"]
+all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect"]
 
 [tool.ruff]
 select = [
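
With the new embeddings extras group, the optional dependency can be pulled in via pip install "langchain[embeddings]" (standard extras behavior); the heavier all extra continues to include it as well.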
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# flake8: noqa
"""Test sentence_transformer embeddings."""

from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma


def test_sentence_transformer_embedding_documents() -> None:
    """Test sentence_transformer embeddings."""
    embedding = SentenceTransformerEmbeddings()
    documents = ["foo bar"]
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 384


def test_sentence_transformer_embedding_query() -> None:
    """Test sentence_transformer embeddings."""
    embedding = SentenceTransformerEmbeddings()
    query = "what the foo is a bar?"
    query_vector = embedding.embed_query(query)
    assert len(query_vector) == 384


def test_sentence_transformer_db_query() -> None:
    """Test sentence_transformer similarity search."""
    embedding = SentenceTransformerEmbeddings()
    texts = [
        "we will foo your bar until you can't foo any more",
        "the quick brown fox jumped over the lazy dog",
    ]
    query = "what the foo is a bar?"
    query_vector = embedding.embed_query(query)
    assert len(query_vector) == 384
    db = Chroma(embedding_function=embedding)
    db.add_texts(texts)
    docs = db.similarity_search_by_vector(query_vector, k=2)
    assert docs[0].page_content == "we will foo your bar until you can't foo any more"
