Skip to content

Commit 2c2faae

Browse files
author
imash
committed
Add MariaDB integration with native VECTOR support
Implement MariaDB database connector and vector store with native VECTOR(384) data type support for efficient similarity search. Includes improvements to SQL extraction and DDL retrieval.
1 parent eb27ee2 commit 2c2faae

File tree

7 files changed

+330
-0
lines changed

7 files changed

+330
-0
lines changed

mindsql/_helper/helper.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ def log_and_return(extracted_sql: str) -> str:
6262
log.info(LOG_AND_RETURN_CONSTANT.format(llm_response, extracted_sql))
6363
return extracted_sql
6464

65+
if "SQLQuery:" in llm_response:
66+
sql_part = llm_response.split("SQLQuery:", 1)[1].strip()
67+
if "\n\n" in sql_part:
68+
sql_part = sql_part.split("\n\n")[0].strip()
69+
return log_and_return(sql_part)
70+
6571
sql_match = re.search(r"```(sql)?\n(.+?)```", llm_response, re.DOTALL)
6672
if sql_match:
6773
return log_and_return(sql_match.group(2).replace("`", ""))

mindsql/_utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
MYSQL_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
1919
MYSQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
2020
MYSQL_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
21+
MARIADB_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
22+
MARIADB_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
23+
MARIADB_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
2124
POSTGRESQL_SHOW_DATABASE_QUERY = "SELECT datname as DATABASE_NAME FROM pg_database WHERE datistemplate = false;"
2225
POSTGRESQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{db}';"
2326
ERROR_DOWNLOADING_SQLITE_DB_CONSTANT = "Error downloading sqlite db: {}"

mindsql/core/mindsql_core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,15 @@ def __get_ddl_statements(self, connection: any, tables: list[str], question: str
176176
Returns:
177177
list[str]: The list of DDL statements.
178178
"""
179+
vector_ddls = []
180+
try:
181+
vector_ddls = self.vectorstore.retrieve_relevant_ddl(question, **kwargs)
182+
except Exception as e:
183+
log.info(f"Vector store retrieval failed: {e}")
184+
185+
if vector_ddls and len(vector_ddls) > 0:
186+
return vector_ddls
187+
179188
if tables and connection:
180189
ddl_statements = []
181190
for table_name in tables:

mindsql/databases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .idatabase import IDatabase
2+
from .mariadb import MariaDB
23
from .mysql import MySql
34
from .postgres import Postgres
45
from .sqlite import Sqlite

mindsql/databases/mariadb.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from typing import List
2+
from urllib.parse import urlparse
3+
4+
import mariadb
5+
import pandas as pd
6+
7+
from .._utils import logger
8+
from .._utils.constants import SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT, ERROR_CONNECTING_TO_DB_CONSTANT, \
9+
INVALID_DB_CONNECTION_OBJECT, ERROR_WHILE_RUNNING_QUERY, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY, \
10+
MARIADB_SHOW_DATABASE_QUERY, MARIADB_SHOW_CREATE_TABLE_QUERY, CONNECTION_ESTABLISH_ERROR_CONSTANT
11+
from . import IDatabase
12+
13+
log = logger.init_loggers("MariaDB")
14+
15+
16+
class MariaDB(IDatabase):
    """MariaDB implementation of ``IDatabase`` built on the ``mariadb`` driver.

    Errors raised by the driver are logged and turned into an "empty" result
    (``None`` connection, empty ``DataFrame``, empty list or empty string)
    instead of propagating, so callers must check for empty results.
    """

    # Statements starting with one of these keywords are committed explicitly
    # even when they produce a result set (e.g. ``INSERT ... RETURNING``).
    _WRITE_PREFIXES = ('CREATE', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER')

    def create_connection(self, url: str, **kwargs) -> any:
        """Open a connection described by a ``scheme://user:pass@host:port/db`` URL.

        Args:
            url: Connection URL. User name, password and database name may be
                percent-encoded; they are decoded before being handed to the driver.
            **kwargs: Extra keyword arguments forwarded to ``mariadb.connect``.
                ``port`` is only used as a fallback when the URL carries no port.

        Returns:
            The open connection, or ``None`` if the driver reported an error.
        """
        # Function-scope import: the module header only imports ``urlparse``.
        from urllib.parse import unquote

        parsed = urlparse(url)
        try:
            connection_params = {
                'host': parsed.hostname,
                'port': parsed.port or int(kwargs.get('port', 3306)),
                # urlparse leaves credentials percent-encoded ("%40" for "@").
                'user': unquote(parsed.username) if parsed.username else parsed.username,
                'password': unquote(parsed.password) if parsed.password else parsed.password,
                # A bare "/" path would otherwise yield an empty database name.
                'database': unquote(parsed.path.lstrip('/')) or None,
                'autocommit': True,
            }

            connection_params = {k: v for k, v in connection_params.items() if v is not None}
            connection_params.update({k: v for k, v in kwargs.items() if k not in ['port']})

            conn = mariadb.connect(**connection_params)
            log.info(SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT.format("MariaDB"))
            return conn

        except mariadb.Error as e:
            log.info(ERROR_CONNECTING_TO_DB_CONSTANT.format("MariaDB", str(e)))
            return None

    def validate_connection(self, connection: any) -> None:
        """Raise ``ValueError`` if *connection* is ``None`` or has no ``cursor`` attribute."""
        if connection is None:
            raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
        if not hasattr(connection, 'cursor'):
            raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("MariaDB"))

    def execute_sql(self, connection, sql: str) -> pd.DataFrame:
        """Run a single SQL statement and return its result set.

        Args:
            connection: An open connection (validated first; an invalid one
                raises ``ValueError``).
            sql: The statement to execute.

        Returns:
            A ``DataFrame`` with one column per result column. An empty
            ``DataFrame`` is returned for statements that produce no result
            set and when the driver raises ``mariadb.Error`` (which is logged).
        """
        try:
            self.validate_connection(connection)
            cursor = connection.cursor()
            try:
                cursor.execute(sql)

                # Check for a result set *before* fetching: calling fetchall()
                # after a statement without one is an error in DB-API drivers.
                if not cursor.description:
                    connection.commit()
                    return pd.DataFrame()

                column_names = [column[0] for column in cursor.description]
                rows = cursor.fetchall()
                if sql.strip().upper().startswith(self._WRITE_PREFIXES):
                    connection.commit()
                return pd.DataFrame(rows, columns=column_names)
            finally:
                # Always release the cursor, including on the error path.
                cursor.close()
        except mariadb.Error as e:
            log.info(ERROR_WHILE_RUNNING_QUERY.format(e))
            return pd.DataFrame()

    def get_databases(self, connection) -> List[str]:
        """Return the distinct database names visible on *connection* (``[]`` on failure)."""
        try:
            self.validate_connection(connection)
            df_databases = self.execute_sql(connection=connection, sql=MARIADB_SHOW_DATABASE_QUERY)
        except Exception as e:
            log.info(e)
            return []
        # execute_sql returns a column-less frame when the query failed.
        if "Database" not in df_databases.columns:
            return []
        return df_databases["Database"].unique().tolist()

    def get_table_names(self, connection, database: str) -> pd.DataFrame:
        """Return the ``information_schema`` table-name rows for *database*."""
        self.validate_connection(connection)
        df_tables = self.execute_sql(connection, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY.format(database))
        return df_tables

    def get_all_ddls(self, connection, database: str) -> pd.DataFrame:
        """Collect the DDL of every table in *database*.

        Returns:
            A ``DataFrame`` with columns ``Table`` and ``DDL``. Tables whose
            DDL could not be retrieved are skipped.
        """
        self.validate_connection(connection)
        df_tables = self.get_table_names(connection, database)

        records = []
        for _, row in df_tables.iterrows():
            # The column label case depends on how the server reports it.
            table_name = row.get('TABLE_NAME') or row.get('table_name')
            if not table_name:
                continue
            ddl = self.get_ddl(connection, table_name)
            if ddl:
                records.append({'Table': table_name, 'DDL': ddl})
        # Build the frame once instead of appending row by row.
        return pd.DataFrame(records, columns=['Table', 'DDL'])

    def get_ddl(self, connection: any, table_name: str, **kwargs) -> str:
        """Return the ``SHOW CREATE TABLE`` statement for *table_name*.

        Returns:
            The DDL text, or an empty string if the query returned nothing.
        """
        # Double embedded backticks so the name stays a single quoted identifier.
        quoted_name = str(table_name).replace('`', '``')
        ddl_df = self.execute_sql(connection, MARIADB_SHOW_CREATE_TABLE_QUERY.format(quoted_name))
        if ddl_df.empty or ddl_df.shape[1] < 2:
            log.info(ERROR_WHILE_RUNNING_QUERY.format(f"no DDL returned for table '{table_name}'"))
            return ""
        # The statement is the second column; its label is "Create Table" for
        # tables but "Create View" for views, so address it by position.
        return ddl_df.iloc[0, 1]

    def get_dialect(self) -> str:
        """Return the SQL dialect name used for this connector (``'mysql'``)."""
        return 'mysql'

mindsql/vectorstores/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .chromadb import ChromaDB
33
from .faiss_db import Faiss
44
from .qdrant import Qdrant
5+
from .mariadb_vector import MariaDBVectorStore
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import json
2+
import uuid
3+
from typing import List
4+
import mariadb
5+
import pandas as pd
6+
from sentence_transformers import SentenceTransformer
7+
from mindsql.vectorstores import IVectorstore
8+
9+
10+
class MariaDBVectorStore(IVectorstore):
    """Vector store backed by MariaDB's native ``VECTOR`` column type.

    Two tables are used: ``<collection_name>`` holds DDL and documentation
    text, ``<collection_name>_sql_pairs`` holds question/SQL pairs. Every row
    carries a sentence-transformer embedding, and retrieval ranks rows by
    cosine distance between the stored embedding and the question embedding.
    """

    def __init__(self, config=None):
        """Configure the store and make sure its tables exist.

        Args:
            config: Mapping with the keys ``host`` (default ``'localhost'``),
                ``port`` (default ``3306``), ``user``, ``password``, and
                optionally ``database``, ``collection_name`` (default
                ``'mindsql_vectors'``), ``embedding_model`` (default
                ``'all-MiniLM-L6-v2'``) and ``reset`` (default ``False``;
                when true, existing tables are dropped and recreated).

        Raises:
            ValueError: If *config* is ``None`` or ``collection_name`` is not a
                plain identifier.
            RuntimeError: If the tables cannot be created.
        """
        if config is None:
            raise ValueError("MariaDB configuration is required")

        self.collection_name = config.get('collection_name', 'mindsql_vectors')
        # The name is interpolated into SQL text, so only accept plain identifiers.
        self._validate_identifier(self.collection_name)

        self.connection_params = {
            'host': config.get('host', 'localhost'),
            'port': int(config.get('port') or 3306),
            'user': config.get('user'),
            'password': config.get('password'),
        }

        if config.get('database'):
            self.connection_params['database'] = config['database']

        self.embedding_model = SentenceTransformer(config.get('embedding_model', "all-MiniLM-L6-v2"))
        # Ask the model for its output size; 384 is the size of the default model.
        self.dimension = self.embedding_model.get_sentence_embedding_dimension() or 384
        self._init_database(reset=bool(config.get('reset', False)))

    @staticmethod
    def _validate_identifier(name):
        """Raise ``ValueError`` unless *name* is made of letters, digits, ``_``, ``$`` or ``.``."""
        if not isinstance(name, str) or not name or not all(ch.isalnum() or ch in "_$." for ch in name):
            raise ValueError(f"Invalid collection name: {name!r}")

    def _init_database(self, reset=False):
        """Create the two backing tables if they are missing.

        Args:
            reset: When true, drop both tables first so the store starts empty.
                The default keeps previously indexed data across restarts.

        Raises:
            RuntimeError: If connecting or any statement fails.
        """
        statements = []
        if reset:
            statements.append(f"DROP TABLE IF EXISTS {self.collection_name}")
            statements.append(f"DROP TABLE IF EXISTS {self.collection_name}_sql_pairs")

        statements.append(f"""
            CREATE TABLE IF NOT EXISTS {self.collection_name} (
                id VARCHAR(36) PRIMARY KEY,
                document TEXT NOT NULL,
                embedding VECTOR({self.dimension}) NOT NULL,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX idx_created_at (created_at),
                FULLTEXT(document),
                VECTOR INDEX (embedding) DISTANCE=cosine
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)
        statements.append(f"""
            CREATE TABLE IF NOT EXISTS {self.collection_name}_sql_pairs (
                id VARCHAR(36) PRIMARY KEY,
                question TEXT NOT NULL,
                sql_query TEXT NOT NULL,
                embedding VECTOR({self.dimension}) NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FULLTEXT(question, sql_query),
                VECTOR INDEX (embedding) DISTANCE=cosine
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

        try:
            conn = mariadb.connect(**self.connection_params)
            try:
                cursor = conn.cursor()
                try:
                    for statement in statements:
                        cursor.execute(statement)
                    conn.commit()
                finally:
                    cursor.close()
            finally:
                conn.close()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize MariaDB vector store: {e}") from e

    def _execute(self, sql, params=(), fetch=False):
        """Run one statement on a fresh connection and always release it.

        Args:
            sql: Statement text using ``?`` placeholders.
            params: Values bound to the placeholders.
            fetch: When true, return all result rows; otherwise commit and
                return the cursor's ``rowcount``.
        """
        conn = mariadb.connect(**self.connection_params)
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(sql, params)
                if fetch:
                    return cursor.fetchall()
                conn.commit()
                return cursor.rowcount
            finally:
                cursor.close()
        finally:
            conn.close()

    def _format_vector_for_insertion(self, embedding_array):
        """Render an embedding as the ``[x1,x2,...]`` text accepted by ``VEC_FromText``.

        Raises:
            ValueError: If the embedding length differs from ``self.dimension``.
        """
        if len(embedding_array) != self.dimension:
            raise ValueError(f"Expected {self.dimension} dimensions, got {len(embedding_array)}")
        return '[' + ','.join(f'{float(x)}' for x in embedding_array) + ']'

    def _embed(self, text):
        """Encode *text* with the embedding model and return its vector literal."""
        return self._format_vector_for_insertion(self.embedding_model.encode(text).tolist())

    def _insert_document(self, document, doc_type):
        """Store *document* in the main table, tagged with ``{"type": doc_type}`` metadata."""
        self._execute(f"""
            INSERT INTO {self.collection_name}
            (id, document, embedding, metadata)
            VALUES (?, ?, VEC_FromText(?), ?)
        """, (str(uuid.uuid4()), document, self._embed(document), json.dumps({"type": doc_type})))

    def _retrieve_documents(self, question, doc_type, n_results):
        """Return up to *n_results* documents of *doc_type*, nearest to *question* first."""
        rows = self._execute(f"""
            SELECT document FROM {self.collection_name}
            WHERE JSON_EXTRACT(metadata, '$.type') = ?
            ORDER BY VEC_DISTANCE_COSINE(embedding, VEC_FromText(?)) ASC LIMIT ?
        """, (doc_type, self._embed(question), int(n_results)), fetch=True)
        return [row[0] for row in rows]

    def add_ddl(self, ddl: str):
        """Embed and store a DDL statement."""
        self._insert_document(ddl, "ddl")

    def add_question_sql(self, question: str, sql: str):
        """Embed *question* and store it together with its SQL answer."""
        self._execute(f"""
            INSERT INTO {self.collection_name}_sql_pairs
            (id, question, sql_query, embedding)
            VALUES (?, ?, ?, VEC_FromText(?))
        """, (str(uuid.uuid4()), question, sql, self._embed(question)))

    def get_similar_question_sql(self, question: str, n_results: int = 5):
        """Return the stored question/SQL pairs closest to *question*.

        Returns:
            A list of dicts with the keys ``question``, ``sql``, ``similarity``
            (1 minus the cosine distance; ``None`` if the server returned no
            distance) and ``text_score`` (the FULLTEXT relevance), ordered
            from most to least similar.
        """
        rows = self._execute(f"""
            SELECT question, sql_query,
                   VEC_DISTANCE_COSINE(embedding, VEC_FromText(?)) AS distance,
                   MATCH(question, sql_query) AGAINST (? IN NATURAL LANGUAGE MODE) AS text_score
            FROM {self.collection_name}_sql_pairs
            ORDER BY distance ASC LIMIT ?
        """, (self._embed(question), question, int(n_results)), fetch=True)

        return [
            {
                'question': stored_question,
                'sql': stored_sql,
                'similarity': None if distance is None else 1.0 - float(distance),
                'text_score': text_score,
            }
            for stored_question, stored_sql, distance, text_score in rows
        ]

    def retrieve_relevant_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL statements most similar to *question* (``n_results``, default 5)."""
        return self._retrieve_documents(question, "ddl", kwargs.get('n_results', 5))

    def retrieve_relevant_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation entries most similar to *question* (``n_results``, default 5)."""
        return self._retrieve_documents(question, "documentation", kwargs.get('n_results', 5))

    def retrieve_relevant_question_sql(self, question: str, **kwargs) -> list:
        """Return the question/SQL pairs most similar to *question* (``n_results``, default 3)."""
        return self.get_similar_question_sql(question, kwargs.get('n_results', 3))

    def index_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair and report the outcome as a status string."""
        try:
            self.add_question_sql(question, sql)
            return "Successfully added question-SQL pair"
        except Exception as e:
            return f"Failed: {e}"

    def index_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement and report the outcome as a status string."""
        try:
            self.add_ddl(ddl)
            return "Successfully added DDL"
        except Exception as e:
            return f"Failed: {e}"

    def index_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation entry and report the outcome as a status string."""
        try:
            self._insert_document(documentation, "documentation")
            return "Successfully added documentation"
        except Exception as e:
            return f"Failed: {e}"

    def fetch_all_vectorstore_data(self, **kwargs) -> pd.DataFrame:
        """Return every stored item as rows of ``id``, ``content``, ``type`` and ``created_at``."""
        # Plain cursors are used instead of pandas.read_sql, which warns about
        # DB-API connections other than sqlite3.
        document_rows = self._execute(
            f"SELECT id, document, created_at FROM {self.collection_name}", fetch=True)
        pair_rows = self._execute(
            f"SELECT id, question, sql_query, created_at FROM {self.collection_name}_sql_pairs", fetch=True)

        data = [
            {'id': item_id, 'content': document, 'type': 'document', 'created_at': created_at}
            for item_id, document, created_at in document_rows
        ]
        data.extend(
            {'id': item_id, 'content': f"Q: {question} | SQL: {sql_query}",
             'type': 'question_sql', 'created_at': created_at}
            for item_id, question, sql_query, created_at in pair_rows
        )
        return pd.DataFrame(data, columns=['id', 'content', 'type', 'created_at'])

    def delete_vectorstore_data(self, item_id: str, **kwargs) -> bool:
        """Delete *item_id* from both tables; return ``True`` if any row was removed."""
        conn = mariadb.connect(**self.connection_params)
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(f"DELETE FROM {self.collection_name} WHERE id = ?", (item_id,))
                main_deleted = cursor.rowcount
                cursor.execute(f"DELETE FROM {self.collection_name}_sql_pairs WHERE id = ?", (item_id,))
                pairs_deleted = cursor.rowcount
                # Both deletes share one commit, as a single unit of work.
                conn.commit()
            finally:
                cursor.close()
        finally:
            conn.close()
        return (main_deleted + pairs_deleted) > 0

    def add_documents(self, documents: List[str]):
        """Store each entry of *documents* through ``add_ddl`` (so each is tagged as DDL)."""
        for doc in documents:
            self.add_ddl(doc)

0 commit comments

Comments
 (0)