Skip to content

Commit eb27ee2

Browse files
Merge pull request #32 from Anush008/master
feat: Qdrant vectorstore support
2 parents f7e24b3 + 28c32a7 commit eb27ee2

File tree

3 files changed

+165
-6
lines changed

3 files changed

+165
-6
lines changed

mindsql/vectorstores/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .ivectorstore import IVectorstore
22
from .chromadb import ChromaDB
33
from .faiss_db import Faiss
4+
from .qdrant import Qdrant

mindsql/vectorstores/qdrant.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import json
2+
import os
3+
import uuid
4+
from typing import List
5+
6+
import pandas as pd
7+
from qdrant_client import QdrantClient
8+
from qdrant_client.http.models import Distance, VectorParams, PointStruct
9+
from sentence_transformers import SentenceTransformer
10+
11+
from . import IVectorstore
12+
13+
# Default embedding model, instantiated once when this module is imported and
# used by Qdrant instances whose config supplies no "embedding_function".
# NOTE(review): the default "dimension" of 1024 in Qdrant.__init__ presumably
# matches this model's output size — confirm before swapping models.
sentence_transformer_ef = SentenceTransformer("WhereIsAI/UAE-Large-V1")
14+
15+
16+
class Qdrant(IVectorstore):
17+
def __init__(self, config=None):
    """Set up the embedding function, vector size and Qdrant client.

    Args:
        config: Optional mapping. Recognised keys are
            "embedding_function" (object exposing ``encode``; defaults to
            the module-level ``sentence_transformer_ef``), "dimension"
            (vector size used when creating collections; defaults to
            1024) and "qdrant_client_options" (keyword arguments forwarded
            to ``QdrantClient``; defaults to none).
    """
    # Treating a missing config as an empty mapping lets a single code
    # path supply every default.
    settings = {} if config is None else config
    self.embedding_function = settings.get(
        "embedding_function", sentence_transformer_ef
    )
    self.dimension = settings.get("dimension", 1024)
    client_kwargs = settings.get("qdrant_client_options", {})
    self.client = QdrantClient(**client_kwargs)
    # Ensure the collections this class reads and writes are present.
    self._init_collections()
30+
31+
def _init_collections(self):
    """Create the "sql", "ddl" and "documentation" collections if absent.

    Collections that already exist are left untouched. New ones are
    created with ``self.dimension`` as vector size and cosine distance.
    """
    for collection in ("sql", "ddl", "documentation"):
        if self.client.collection_exists(collection_name=collection):
            continue
        vector_settings = VectorParams(
            size=self.dimension, distance=Distance.COSINE
        )
        self.client.create_collection(
            collection_name=collection, vectors_config=vector_settings
        )
40+
41+
def index_question_sql(self, question: str, sql: str, **kwargs) -> str:
    """Embed a question/SQL pair and store it in the "sql" collection.

    The pair is serialised to a JSON string; that string is both the text
    that gets embedded and the value stored under the "data" payload key.

    Args:
        question: Natural-language question.
        sql: SQL statement associated with the question.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The generated point UUID with "-sql" appended, which is the form
        ``delete_vectorstore_data`` parses.
    """
    question_sql_json = json.dumps(
        {"question": question, "sql": sql}, ensure_ascii=False
    )
    chunk_id = str(uuid.uuid4())
    vector = self.embedding_function.encode([question_sql_json])[0]
    # The encoder may hand back an array object rather than a list; give
    # PointStruct a plain list of floats so the point validates.
    if hasattr(vector, "tolist"):
        vector = vector.tolist()
    self.client.upsert(
        collection_name="sql",
        points=[
            PointStruct(
                id=chunk_id,
                vector=vector,
                payload={"data": question_sql_json},
            )
        ],
    )
    return chunk_id + "-sql"
56+
57+
def index_ddl(self, ddl: str, **kwargs) -> str:
    """Embed a DDL statement and store it in the "ddl" collection.

    Args:
        ddl: DDL text; it is embedded and stored under the "data" payload
            key.
        **kwargs: ``table`` (optional) is stored under the "table_name"
            payload key when truthy.

    Returns:
        The generated point UUID with "-ddl" appended, which is the form
        ``delete_vectorstore_data`` parses.
    """
    chunk_id = str(uuid.uuid4())
    table = kwargs.get("table", None)
    vector = self.embedding_function.encode([ddl])[0]
    # The encoder may hand back an array object rather than a list; give
    # PointStruct a plain list of floats so the point validates.
    if hasattr(vector, "tolist"):
        vector = vector.tolist()
    payload = {"data": ddl}
    if table:
        payload["table_name"] = table
    self.client.upsert(
        collection_name="ddl",
        points=[PointStruct(id=chunk_id, vector=vector, payload=payload)],
    )
    return chunk_id + "-ddl"
69+
70+
def index_documentation(self, documentation: str, **kwargs) -> str:
    """Embed a documentation snippet and store it in "documentation".

    Args:
        documentation: Text to embed; also stored under the "data"
            payload key.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The generated point UUID with "-doc" appended, which is the form
        ``delete_vectorstore_data`` parses.
    """
    chunk_id = str(uuid.uuid4())
    vector = self.embedding_function.encode([documentation])[0]
    # The encoder may hand back an array object rather than a list; give
    # PointStruct a plain list of floats so the point validates.
    if hasattr(vector, "tolist"):
        vector = vector.tolist()
    self.client.upsert(
        collection_name="documentation",
        points=[
            PointStruct(
                id=chunk_id, vector=vector, payload={"data": documentation}
            )
        ],
    )
    return chunk_id + "-doc"
80+
81+
def fetch_all_vectorstore_data(self, **kwargs) -> pd.DataFrame:
    """Return every stored training item as a DataFrame.

    Each of the "sql", "ddl" and "documentation" collections is scrolled
    page by page until the client reports no further offset, so
    collections larger than one page are returned in full.

    Args:
        **kwargs: ``batch_size`` (optional, default 10000) is the number
            of points requested per scroll call.

    Returns:
        DataFrame with columns "id", "question", "content" and
        "training_data_type". Ids carry the same "-sql"/"-ddl"/"-doc"
        suffix that the index_* methods return, so they can be passed
        straight to ``delete_vectorstore_data``. For "sql" rows the
        stored JSON is split into question and content; other rows have
        no question. The columns are present even when nothing is stored.
    """
    # Collection name -> id suffix, mirroring the index_* return values.
    id_suffixes = {"sql": "-sql", "ddl": "-ddl", "documentation": "-doc"}
    columns = ["id", "question", "content", "training_data_type"]
    batch_size = kwargs.get("batch_size", 10000)

    rows = []
    for name, suffix in id_suffixes.items():
        offset = None
        while True:
            # scroll returns (points, next_offset); next_offset is None
            # once the last page has been delivered.
            points, offset = self.client.scroll(
                collection_name=name, limit=batch_size, offset=offset
            )
            for point in points:
                payload = point.payload or {}
                if name == "sql":
                    doc = json.loads(payload.get("data", "{}"))
                    question = doc.get("question")
                    content = doc.get("sql")
                else:
                    question = None
                    content = payload.get("data")
                rows.append(
                    {
                        "id": f"{point.id}{suffix}",
                        "question": question,
                        "content": content,
                        "training_data_type": name,
                    }
                )
            if offset is None:
                break
    return pd.DataFrame(rows, columns=columns)
103+
104+
def delete_vectorstore_data(self, item_id: str, **kwargs) -> bool:
    """Delete one stored item identified by its suffixed id.

    Args:
        item_id: Id whose last four characters ("-sql", "-ddl" or "-doc")
            select the collection; everything before them is the point id.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        True once the delete has been issued for a recognised suffix,
        False when the suffix is not recognised (nothing is deleted).
    """
    point_id = item_id[:-4]
    collection_by_suffix = {
        "-sql": "sql",
        "-ddl": "ddl",
        "-doc": "documentation",
    }
    target = collection_by_suffix.get(item_id[-4:])
    if target is None:
        return False
    self.client.delete(collection_name=target, points_selector=[point_id])
    return True
119+
120+
def remove_collection(self, collection_name: str) -> bool:
    """Empty a collection by deleting it and creating it again.

    Args:
        collection_name: Name of the collection to reset.

    Returns:
        True if the collection existed and was recreated with
        ``self.dimension`` and cosine distance; False if it did not exist,
        in which case nothing is created.
    """
    if not self.client.collection_exists(collection_name=collection_name):
        return False
    self.client.delete_collection(collection_name=collection_name)
    fresh_config = VectorParams(
        size=self.dimension, distance=Distance.COSINE
    )
    self.client.create_collection(
        collection_name=collection_name, vectors_config=fresh_config
    )
    return True
131+
132+
def retrieve_relevant_question_sql(self, question: str, **kwargs) -> list:
    """Find stored question/SQL pairs closest to ``question``.

    Args:
        question: Text to embed and use as the query.
        **kwargs: ``n_results`` (optional, default 2) caps the number of
            hits requested.

    Returns:
        A list with one decoded JSON object per hit, taken from the hit's
        "data" payload key (an empty dict when that key is missing).
    """
    limit = kwargs.get("n_results", 2)
    query_vector = self.embedding_function.encode([question])[0]
    response = self.client.query_points(
        collection_name="sql", query=query_vector, limit=limit
    )
    return [
        json.loads(hit.payload.get("data", "{}")) for hit in response.points
    ]
143+
144+
def retrieve_relevant_ddl(self, question: str, **kwargs) -> list:
    """Find stored DDL entries closest to ``question``.

    Args:
        question: Text to embed and use as the query.
        **kwargs: ``n_results`` (optional, default 2) caps the number of
            hits requested.

    Returns:
        A list holding each hit's "data" payload value (None where the
        key is missing), in the order the client returned the hits.
    """
    limit = kwargs.get("n_results", 2)
    query_vector = self.embedding_function.encode([question])[0]
    response = self.client.query_points(
        collection_name="ddl", query=query_vector, limit=limit
    )
    statements = []
    for hit in response.points:
        statements.append(hit.payload.get("data"))
    return statements
151+
152+
def retrieve_relevant_documentation(self, question: str, **kwargs) -> list:
    """Find stored documentation entries closest to ``question``.

    Args:
        question: Text to embed and use as the query.
        **kwargs: ``n_results`` (optional, default 2) caps the number of
            hits requested.

    Returns:
        A list holding each hit's "data" payload value (None where the
        key is missing), in the order the client returned the hits.
    """
    limit = kwargs.get("n_results", 2)
    query_vector = self.embedding_function.encode([question])[0]
    response = self.client.query_points(
        collection_name="documentation", query=query_vector, limit=limit
    )
    snippets = []
    for hit in response.points:
        snippets.append(hit.payload.get("data"))
    return snippets

pyproject.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ classifiers = [
1616

1717

1818
[tool.poetry.dependencies]
19-
python = "^3.10"
20-
chromadb = "^0.4.22"
21-
pandas = "2.2.0"
19+
python = "^3.11"
20+
chromadb = "^1.0.15"
21+
pandas = "2.3.1"
2222
plotly = "5.19.0"
2323
mysql-connector-python = "^8.3.0"
2424
google-generativeai="0.3.2"
2525
llama-cpp-python = "0.2.47"
2626
openai = "^1.12.0"
2727
sqlparse = "^0.4.4"
28-
numpy = "^1.26.4"
28+
numpy = "2.3.1"
2929
sentence-transformers = "^2.3.1"
3030
psycopg2-binary = "^2.9.9"
31-
faiss-cpu = "^1.8.0"
32-
pysqlite3-binary = "^0.5.2.post3"
31+
faiss-cpu = "^1.11.0.post1"
3332
transformers = "^4.38.2"
33+
qdrant-client = "^1.14.3"
3434

3535

3636
[build-system]

0 commit comments

Comments
 (0)