Skip to content

Commit 6eb0403

Browse files
author
imash
committed
Add MariaDB integration with native VECTOR support
Implement MariaDB database connector and vector store with native VECTOR(384) data type support for efficient similarity search. Includes improvements to SQL extraction and DDL retrieval.
1 parent eb27ee2 commit 6eb0403

File tree

8 files changed

+347
-9
lines changed

8 files changed

+347
-9
lines changed

mindsql/_helper/helper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,21 @@ def log_and_return(extracted_sql: str) -> str:
6262
log.info(LOG_AND_RETURN_CONSTANT.format(llm_response, extracted_sql))
6363
return extracted_sql
6464

65+
if "SQLQuery:" in llm_response:
66+
sql_part = llm_response.split("SQLQuery:", 1)[1].strip()
67+
if "\n\n" in sql_part:
68+
sql_part = sql_part.split("\n\n")[0].strip()
69+
return log_and_return(sql_part)
70+
6571
sql_match = re.search(r"```(sql)?\n(.+?)```", llm_response, re.DOTALL)
6672
if sql_match:
6773
return log_and_return(sql_match.group(2).replace("`", ""))
74+
6875
elif has_select_and_semicolon(llm_response):
6976
start_sql = llm_response.find("SELECT")
7077
end_sql = llm_response.find(";")
7178
return log_and_return(llm_response[start_sql:end_sql + 1].replace("`", ""))
79+
7280
return llm_response
7381

7482

mindsql/_utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
MYSQL_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
1919
MYSQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
2020
MYSQL_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
21+
MARIADB_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
22+
MARIADB_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
23+
MARIADB_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
2124
POSTGRESQL_SHOW_DATABASE_QUERY = "SELECT datname as DATABASE_NAME FROM pg_database WHERE datistemplate = false;"
2225
POSTGRESQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{db}';"
2326
ERROR_DOWNLOADING_SQLITE_DB_CONSTANT = "Error downloading sqlite db: {}"

mindsql/_utils/prompts.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
DEFAULT_PROMPT: str = """As a {dialect_name} expert, your task is to generate SQL queries based on user questions. Ensure that your {dialect_name} queries are syntactically correct and tailored to the user's inquiry. Retrieve at most 10 results using the LIMIT clause and order them for relevance. Avoid querying for all columns from a table. Select only the necessary columns wrapped in backticks (`). Use CURDATE() to handle 'today' queries and employ the LIKE clause for precise matches in {dialect_name}. Carefully consider column names and their respective tables to avoid querying non-existent columns. Stop after delivering the SQLQuery, avoiding follow-up questions.
1+
DEFAULT_PROMPT: str = """As a {dialect_name} expert, your task is to generate accurate SQL queries based on user questions and the provided table schemas.
22
3-
Follow this format:
3+
Make sure you look at the DDL statements to understand what columns are available in each table. When someone asks to "show" or "display" a specific table, just select from that table directly. Only query columns that actually exist - this is really important because querying non-existent columns will cause errors. Pick the columns that make sense for answering what the user is asking about.
4+
5+
Always wrap your table and column names with backticks (`) since that's the {dialect_name} way of handling identifiers. Try to add LIMIT 10 to your queries so we don't return too much data at once. If it makes sense, use ORDER BY to sort the results in a logical way. For queries about "today", you can use the CURDATE() function. Also make sure any filter values match exactly what's in the schema.
6+
7+
Your response should follow this format:
48
Question: User's question here
59
SQLQuery: Your SQL query without preamble
610
@@ -54,6 +58,5 @@
5458
- Ensure that the code is well-commented for readability and syntactically correct.
5559
"""
5660

57-
SQL_EXCEPTION_RESPONSE = """Apologies for the inconvenience! 🙏 It seems the database is currently experiencing a bit
58-
of a hiccup and isn't cooperating as we'd like. 🤖"""
61+
SQL_EXCEPTION_RESPONSE = """Apologies for the inconvenience! It seems the database is currently experiencing a bit of a hiccup and isn't cooperating as we'd like."""
5962

mindsql/core/mindsql_core.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def create_database_query(self, question: str, connection, tables: list, **kwarg
4949
question_sql_list = self.vectorstore.retrieve_relevant_question_sql(question, **kwargs)
5050
prompt = self.build_sql_prompt(question=question, connection=connection, question_sql_list=question_sql_list,
5151
tables=tables, **kwargs)
52-
log.info(prompt)
52+
# log.info(prompt)
5353
llm_response = self.llm.invoke(prompt, **kwargs)
5454
return _helper.helper.extract_sql(llm_response)
5555

@@ -176,13 +176,25 @@ def __get_ddl_statements(self, connection: any, tables: list[str], question: str
176176
Returns:
177177
list[str]: The list of DDL statements.
178178
"""
179+
vector_ddls = []
180+
try:
181+
vector_ddls = self.vectorstore.retrieve_relevant_ddl(question, **kwargs)
182+
except Exception as e:
183+
log.info(f"Vector store retrieval failed: {e}")
184+
185+
if vector_ddls and len(vector_ddls) > 0:
186+
return vector_ddls
187+
179188
if tables and connection:
180189
ddl_statements = []
181190
for table_name in tables:
182-
ddl_statements.append(self.database.get_ddl(connection=connection, table_name=table_name))
183-
else:
184-
ddl_statements = self.vectorstore.retrieve_relevant_ddl(question, **kwargs)
185-
return ddl_statements
191+
try:
192+
ddl_statements.append(self.database.get_ddl(connection=connection, table_name=table_name))
193+
except Exception as e:
194+
log.info(f"Failed to get DDL for table {table_name}: {e}")
195+
return ddl_statements
196+
197+
return []
186198

187199
def ask_db(self, connection, question: Union[str, None] = None, table_names: list = None, visualize: bool = False,
188200
**kwargs) -> dict:

mindsql/databases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .idatabase import IDatabase
2+
from .mariadb import MariaDB
23
from .mysql import MySql
34
from .postgres import Postgres
45
from .sqlite import Sqlite

mindsql/databases/mariadb.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from typing import List
2+
from urllib.parse import urlparse
3+
4+
import mariadb
5+
import pandas as pd
6+
7+
from .._utils import logger
8+
from .._utils.constants import SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT, ERROR_CONNECTING_TO_DB_CONSTANT, \
9+
INVALID_DB_CONNECTION_OBJECT, ERROR_WHILE_RUNNING_QUERY, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY, \
10+
MARIADB_SHOW_DATABASE_QUERY, MARIADB_SHOW_CREATE_TABLE_QUERY, CONNECTION_ESTABLISH_ERROR_CONSTANT
11+
from . import IDatabase
12+
13+
log = logger.init_loggers("MariaDB")
14+
15+
16+
class MariaDB(IDatabase):
    """MariaDB implementation of the ``IDatabase`` connector interface.

    Connections are opened with the ``mariadb`` connector package and query
    results are handed back as pandas ``DataFrame`` objects.
    """

    # Leading keywords of statements that are committed explicitly and for
    # which an empty DataFrame is returned without fetching rows.
    _WRITE_STATEMENT_PREFIXES = ('CREATE', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER')

    # Result columns that can carry the DDL text in a SHOW CREATE TABLE
    # result: "Create Table" for base tables, "Create View" for views.
    _DDL_COLUMNS = ("Create Table", "Create View")

    def create_connection(self, url: str, **kwargs) -> any:
        """Open a MariaDB connection described by a connection URL.

        Args:
            url: Connection URL; host, port, user, password and database are
                taken from its components.
            **kwargs: Extra keyword arguments forwarded to ``mariadb.connect``.
                They override values derived from the URL. ``port`` is only
                used as a fallback when the URL carries no port.

        Returns:
            The connection object, or ``None`` when ``mariadb.connect`` fails
            with ``mariadb.Error`` (the error is logged).
        """
        parsed = urlparse(url)
        try:
            connection_params = {
                'host': parsed.hostname,
                'port': parsed.port or int(kwargs.get('port', 3306)),
                'user': parsed.username,
                'password': parsed.password,
                # A missing or bare "/" path means "no default database";
                # map it to None so the key is dropped below.
                'database': parsed.path.lstrip('/') or None,
                'autocommit': True,
            }

            # Drop unset URL components so the connector defaults apply.
            connection_params = {k: v for k, v in connection_params.items() if v is not None}
            connection_params.update({k: v for k, v in kwargs.items() if k != 'port'})

            conn = mariadb.connect(**connection_params)
            log.info(SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT.format("MariaDB"))
            return conn

        except mariadb.Error as e:
            log.info(ERROR_CONNECTING_TO_DB_CONSTANT.format("MariaDB", str(e)))
            return None

    def validate_connection(self, connection: any) -> None:
        """Check that ``connection`` looks like a usable DB connection.

        Args:
            connection: The object to check.

        Raises:
            ValueError: If ``connection`` is ``None`` or has no ``cursor``
                attribute.
        """
        if connection is None:
            raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
        if not hasattr(connection, 'cursor'):
            raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("MariaDB"))

    def execute_sql(self, connection, sql: str) -> pd.DataFrame:
        """Run one SQL statement and return its result set.

        Args:
            connection: An open connection (validated first).
            sql: The SQL statement to execute.

        Returns:
            A DataFrame with one column per result column. An empty DataFrame
            is returned for write statements, for statements that produce no
            result set, and when the connector raises ``mariadb.Error`` (the
            error is logged).

        Raises:
            ValueError: If ``connection`` fails ``validate_connection``.
        """
        try:
            self.validate_connection(connection)
            cursor = connection.cursor()
            try:
                cursor.execute(sql)

                if sql.strip().upper().startswith(self._WRITE_STATEMENT_PREFIXES):
                    connection.commit()
                    return pd.DataFrame()

                # Check for a result set before fetching: statements such as
                # USE or SET have no description, and fetching from them
                # would be reported as a query error.
                if not cursor.description:
                    return pd.DataFrame()

                column_names = [column[0] for column in cursor.description]
                return pd.DataFrame(cursor.fetchall(), columns=column_names)
            finally:
                # Always release the cursor, including when execute/fetch raised.
                cursor.close()
        except mariadb.Error as e:
            log.info(ERROR_WHILE_RUNNING_QUERY.format(e))
            return pd.DataFrame()

    def get_databases(self, connection) -> List[str]:
        """List the database names visible on the connection.

        Args:
            connection: An open connection.

        Returns:
            The unique database names, or an empty list when the connection
            is invalid or the query produced no ``Database`` column.
        """
        try:
            self.validate_connection(connection)
            df_databases = self.execute_sql(connection=connection, sql=MARIADB_SHOW_DATABASE_QUERY)
        except Exception as e:
            log.info(e)
            return []
        # execute_sql returns a column-less frame when the query failed.
        if "Database" not in df_databases.columns:
            return []
        return df_databases["Database"].unique().tolist()

    def get_table_names(self, connection, database: str) -> pd.DataFrame:
        """Return the table names of ``database`` from ``information_schema``.

        Args:
            connection: An open connection.
            database: Name of the schema whose tables are listed.

        Returns:
            The DataFrame produced by the information-schema query.

        Raises:
            ValueError: If ``connection`` fails ``validate_connection``.
        """
        self.validate_connection(connection)
        # NOTE(review): ``database`` is formatted straight into the query
        # text; callers must not pass untrusted values here.
        return self.execute_sql(connection, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY.format(database))

    def get_all_ddls(self, connection, database: str) -> pd.DataFrame:
        """Collect the DDL statement of every table in ``database``.

        Args:
            connection: An open connection.
            database: Name of the schema to inspect.

        Returns:
            A DataFrame with columns ``Table`` and ``DDL``, one row per table
            whose DDL could be retrieved. Tables whose DDL lookup fails are
            logged and skipped.

        Raises:
            ValueError: If ``connection`` fails ``validate_connection``.
        """
        self.validate_connection(connection)
        df_tables = self.get_table_names(connection, database)

        # Collect rows first and build the frame once, instead of growing a
        # DataFrame row by row through the private ``_append`` method.
        records = []
        for _, row in df_tables.iterrows():
            # The column name may come back in either case.
            table_name = row.get('TABLE_NAME') or row.get('table_name')
            if not table_name:
                continue
            try:
                ddl = self.get_ddl(connection, table_name)
            except KeyError as e:
                log.info(f"Failed to get DDL for table {table_name}: {e}")
                continue
            records.append({'Table': table_name, 'DDL': ddl})
        return pd.DataFrame(records, columns=['Table', 'DDL'])

    def get_ddl(self, connection: any, table_name: str, **kwargs) -> str:
        """Return the ``SHOW CREATE TABLE`` DDL text for one table or view.

        Args:
            connection: An open connection.
            table_name: Unquoted name of the table or view.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The DDL statement text.

        Raises:
            KeyError: If the query returned no DDL (for example because the
                table does not exist or the query failed).
            ValueError: If ``connection`` fails ``validate_connection``.
        """
        # Double embedded backticks so the name stays a single quoted identifier.
        quoted_name = table_name.replace("`", "``")
        ddl_df = self.execute_sql(connection, MARIADB_SHOW_CREATE_TABLE_QUERY.format(quoted_name))
        if not ddl_df.empty:
            for column in self._DDL_COLUMNS:
                if column in ddl_df.columns:
                    return ddl_df[column].iloc[0]
        raise KeyError(f"No DDL returned for table {table_name}")

    def get_dialect(self) -> str:
        """Return the SQL dialect name reported for this connector.

        Returns:
            The string ``'mysql'``.
        """
        # NOTE(review): presumably 'mysql' is reported because downstream
        # prompt/dialect handling treats MariaDB as MySQL-compatible - confirm.
        return 'mysql'

mindsql/vectorstores/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .chromadb import ChromaDB
33
from .faiss_db import Faiss
44
from .qdrant import Qdrant
5+
from .mariadb_vector import MariaDBVectorStore

0 commit comments

Comments
 (0)