Skip to content

Commit 8485953

Browse files
author
imash
committed
Add MariaDB integration with native VECTOR(384) support
- Add MariaDB database connector using official mariadb Python connector - Add MariaDB Vector Store with native VECTOR(384) data type - Improve DDL retrieval with vector search prioritization - Enhance SQL extraction and prompt engineering - Unify embedding model (all-MiniLM-L6-v2) across all vectorstores - Add configurable model support for Google Gemini - Update README with MariaDB documentation Features: - Native VECTOR(384) support in MariaDB 10.7+ - Hybrid vector-relational storage - FULLTEXT indexing for enhanced search - Backward compatible with existing code
1 parent eb27ee2 commit 8485953

File tree

13 files changed

+647
-20
lines changed

13 files changed

+647
-20
lines changed

README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# 🧠 MindSQL
22

3-
MindSQL is a Python RAG (Retrieval-Augmented Generation) Library designed to streamline the interaction between users and their databases using just a few lines of code. With seamless integration for renowned databases such as PostgreSQL, MySQL, and SQLite, MindSQL also extends its capabilities to major databases like Snowflake and BigQuery by extending the `IDatabase` Interface. This library utilizes large language models (LLM) like GPT-4, Llama 2, Google Gemini, and supports knowledge bases like ChromaDB and Faiss.
3+
MindSQL is a Python RAG (Retrieval-Augmented Generation) Library designed to streamline the interaction between users and their databases using just a few lines of code. With seamless integration for renowned databases such as PostgreSQL, MySQL, MariaDB, and SQLite, MindSQL also extends its capabilities to major databases like Snowflake and BigQuery by extending the `IDatabase` Interface. This library utilizes large language models (LLM) like GPT-4, Llama 2, Google Gemini, and supports vector stores like ChromaDB, FAISS, Qdrant, and MariaDB Vector (with native VECTOR data type support).
44

55
![MindSQL Chart](https://github.com/Sammindinventory/MindSQL/assets/77489054/bc993117-8da9-4b4f-b217-8a33db65c342)
66

@@ -107,7 +107,4 @@ We value your feedback and strive to improve MindSQL. Here's how you can share y
107107
- Open an issue to provide general feedback, suggestions, or comments.
108108
- Be constructive and specific in your feedback to help us understand your perspective better.
109109

110-
Thank you for your interest in contributing to our project! We appreciate your support and look forward to working with you. 🚀
111-
112-
113-
110+
Thank you for your interest in contributing to our project! We appreciate your support and look forward to working with you. 🚀

mindsql/_helper/helper.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,26 @@ def log_and_return(extracted_sql: str) -> str:
6262
log.info(LOG_AND_RETURN_CONSTANT.format(llm_response, extracted_sql))
6363
return extracted_sql
6464

65+
# Check for SQLQuery: label (common LLM format)
66+
if "SQLQuery:" in llm_response:
67+
# Extract everything after SQLQuery:
68+
sql_part = llm_response.split("SQLQuery:", 1)[1].strip()
69+
# Remove any trailing text after the query
70+
if "\n\n" in sql_part:
71+
sql_part = sql_part.split("\n\n")[0].strip()
72+
return log_and_return(sql_part)
73+
74+
# Check for SQL in code blocks
6575
sql_match = re.search(r"```(sql)?\n(.+?)```", llm_response, re.DOTALL)
6676
if sql_match:
6777
return log_and_return(sql_match.group(2).replace("`", ""))
78+
79+
# Check for SELECT statements
6880
elif has_select_and_semicolon(llm_response):
6981
start_sql = llm_response.find("SELECT")
7082
end_sql = llm_response.find(";")
7183
return log_and_return(llm_response[start_sql:end_sql + 1].replace("`", ""))
84+
7285
return llm_response
7386

7487

mindsql/_utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
MYSQL_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
1919
MYSQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
2020
MYSQL_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
21+
MARIADB_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
22+
MARIADB_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
23+
MARIADB_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
2124
POSTGRESQL_SHOW_DATABASE_QUERY = "SELECT datname as DATABASE_NAME FROM pg_database WHERE datistemplate = false;"
2225
POSTGRESQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{db}';"
2326
ERROR_DOWNLOADING_SQLITE_DB_CONSTANT = "Error downloading sqlite db: {}"

mindsql/_utils/prompts.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1-
DEFAULT_PROMPT: str = """As a {dialect_name} expert, your task is to generate SQL queries based on user questions. Ensure that your {dialect_name} queries are syntactically correct and tailored to the user's inquiry. Retrieve at most 10 results using the LIMIT clause and order them for relevance. Avoid querying for all columns from a table. Select only the necessary columns wrapped in backticks (`). Use CURDATE() to handle 'today' queries and employ the LIKE clause for precise matches in {dialect_name}. Carefully consider column names and their respective tables to avoid querying non-existent columns. Stop after delivering the SQLQuery, avoiding follow-up questions.
2-
3-
Follow this format:
1+
DEFAULT_PROMPT: str = """As a {dialect_name} expert, your task is to generate accurate SQL queries based on user questions and the provided table schemas.
2+
3+
CRITICAL INSTRUCTIONS:
4+
1. Carefully analyze the table schemas provided in the DDL statements below
5+
2. When user asks to "show TABLE_NAME" or "display TABLE_NAME table", select from that specific table
6+
3. Examine which columns exist in which tables - only query columns that actually exist
7+
4. Select meaningful, relevant columns that answer the user's question (avoid unnecessary columns)
8+
5. Use backticks (`) to wrap table and column names for {dialect_name} compatibility
9+
6. Add LIMIT clause (maximum 10 rows) to prevent excessive results
10+
7. Use ORDER BY to organize results logically
11+
8. Use CURDATE() function for queries involving "today"
12+
9. Match filter values exactly as they appear in the schema (case-sensitive)
13+
10. Double-check your table and column names against the provided DDL before generating SQL
14+
15+
Follow this exact format:
416
Question: User's question here
517
SQLQuery: Your SQL query without preamble
618
@@ -54,6 +66,5 @@
5466
- Ensure that the code is well-commented for readability and syntactically correct.
5567
"""
5668

57-
SQL_EXCEPTION_RESPONSE = """Apologies for the inconvenience! 🙏 It seems the database is currently experiencing a bit
58-
of a hiccup and isn't cooperating as we'd like. 🤖"""
69+
SQL_EXCEPTION_RESPONSE = """Apologies for the inconvenience! It seems the database is currently experiencing a bit of a hiccup and isn't cooperating as we'd like."""
5970

mindsql/core/mindsql_core.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def create_database_query(self, question: str, connection, tables: list, **kwarg
4949
question_sql_list = self.vectorstore.retrieve_relevant_question_sql(question, **kwargs)
5050
prompt = self.build_sql_prompt(question=question, connection=connection, question_sql_list=question_sql_list,
5151
tables=tables, **kwargs)
52-
log.info(prompt)
52+
# log.info(prompt) # Don't show full prompt to users
5353
llm_response = self.llm.invoke(prompt, **kwargs)
5454
return _helper.helper.extract_sql(llm_response)
5555

@@ -176,13 +176,28 @@ def __get_ddl_statements(self, connection: any, tables: list[str], question: str
176176
Returns:
177177
list[str]: The list of DDL statements.
178178
"""
179+
# Try vector store first (semantic search - best for finding relevant tables)
180+
vector_ddls = []
181+
try:
182+
vector_ddls = self.vectorstore.retrieve_relevant_ddl(question, **kwargs)
183+
except Exception as e:
184+
log.info(f"Vector store retrieval failed: {e}")
185+
186+
# If vector store returns good results, use them
187+
if vector_ddls and len(vector_ddls) > 0:
188+
return vector_ddls
189+
190+
# Fallback: get all DDLs from database if vector store fails
179191
if tables and connection:
180192
ddl_statements = []
181193
for table_name in tables:
182-
ddl_statements.append(self.database.get_ddl(connection=connection, table_name=table_name))
183-
else:
184-
ddl_statements = self.vectorstore.retrieve_relevant_ddl(question, **kwargs)
185-
return ddl_statements
194+
try:
195+
ddl_statements.append(self.database.get_ddl(connection=connection, table_name=table_name))
196+
except Exception as e:
197+
log.info(f"Failed to get DDL for table {table_name}: {e}")
198+
return ddl_statements
199+
200+
return []
186201

187202
def ask_db(self, connection, question: Union[str, None] = None, table_names: list = None, visualize: bool = False,
188203
**kwargs) -> dict:

mindsql/databases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .idatabase import IDatabase
2+
from .mariadb import MariaDB
23
from .mysql import MySql
34
from .postgres import Postgres
45
from .sqlite import Sqlite

mindsql/databases/mariadb.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
from typing import List
2+
from urllib.parse import urlparse
3+
4+
import mariadb
5+
import pandas as pd
6+
7+
from .._utils import logger
8+
from .._utils.constants import SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT, ERROR_CONNECTING_TO_DB_CONSTANT, \
9+
INVALID_DB_CONNECTION_OBJECT, ERROR_WHILE_RUNNING_QUERY, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY, \
10+
MARIADB_SHOW_DATABASE_QUERY, MARIADB_SHOW_CREATE_TABLE_QUERY, CONNECTION_ESTABLISH_ERROR_CONSTANT
11+
from . import IDatabase
12+
13+
log = logger.init_loggers("MariaDB")
14+
15+
16+
class MariaDB(IDatabase):
17+
def create_connection(self, url: str, **kwargs) -> any:
18+
"""
19+
A method to create a connection with MariaDB database.
20+
21+
Parameters:
22+
url (str): The URL in the format mariadb://username:password@host:port/database_name
23+
**kwargs: Additional keyword arguments for the connection.
24+
25+
Returns:
26+
any: The connection object.
27+
"""
28+
url = urlparse(url)
29+
try:
30+
# Use official MariaDB connector
31+
connection_params = {
32+
'host': url.hostname,
33+
'port': url.port or int(kwargs.get('port', 3306)),
34+
'user': url.username,
35+
'password': url.password,
36+
'database': url.path.lstrip('/') if url.path else None,
37+
'autocommit': True,
38+
}
39+
40+
# Remove None values and add any additional kwargs
41+
connection_params = {k: v for k, v in connection_params.items() if v is not None}
42+
connection_params.update({k: v for k, v in kwargs.items() if k not in ['port']})
43+
44+
conn = mariadb.connect(**connection_params)
45+
46+
log.info(SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT.format("MariaDB"))
47+
return conn
48+
49+
except mariadb.Error as e:
50+
error_msg = str(e)
51+
log.info(ERROR_CONNECTING_TO_DB_CONSTANT.format("MariaDB", error_msg))
52+
return None
53+
54+
def validate_connection(self, connection: any) -> None:
55+
"""
56+
A function that validates if the provided connection is a MariaDB connection.
57+
58+
Parameters:
59+
connection: The connection object for accessing the database.
60+
61+
Raises:
62+
ValueError: If the provided connection is not a MariaDB connection.
63+
64+
Returns:
65+
None
66+
"""
67+
if connection is None:
68+
raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
69+
70+
# MariaDB connection validation (using PyMySQL connection)
71+
if not hasattr(connection, 'cursor'):
72+
raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("MariaDB"))
73+
74+
def execute_sql(self, connection, sql: str) -> pd.DataFrame:
75+
"""
76+
A method to execute SQL on the database.
77+
78+
Parameters:
79+
connection (any): The connection object.
80+
sql (str): The SQL to be executed.
81+
82+
Returns:
83+
pd.DataFrame: The result of the SQL query.
84+
"""
85+
try:
86+
self.validate_connection(connection)
87+
cursor = connection.cursor()
88+
cursor.execute(sql)
89+
90+
# For DDL/DML statements (CREATE, INSERT, UPDATE, DELETE), commit and return empty DataFrame
91+
if sql.strip().upper().startswith(('CREATE', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER')):
92+
connection.commit()
93+
cursor.close()
94+
return pd.DataFrame()
95+
96+
# For SELECT statements, fetch results
97+
results = cursor.fetchall()
98+
if cursor.description:
99+
column_names = [i[0] for i in cursor.description]
100+
df = pd.DataFrame(results, columns=column_names)
101+
else:
102+
df = pd.DataFrame()
103+
cursor.close()
104+
return df
105+
except mariadb.Error as e:
106+
log.info(ERROR_WHILE_RUNNING_QUERY.format(e))
107+
return pd.DataFrame()
108+
109+
def get_databases(self, connection) -> List[str]:
110+
"""
111+
Get a list of databases from the given connection and SQL query.
112+
113+
Parameters:
114+
connection (object): The connection object for the database.
115+
116+
Returns:
117+
List[str]: A list of unique database names.
118+
"""
119+
try:
120+
self.validate_connection(connection)
121+
df_databases = self.execute_sql(connection=connection, sql=MARIADB_SHOW_DATABASE_QUERY)
122+
except Exception as e:
123+
log.info(e)
124+
return []
125+
126+
return df_databases["Database"].unique().tolist()
127+
128+
def get_table_names(self, connection, database: str) -> pd.DataFrame:
129+
"""
130+
Retrieves the tables from the information schema for the specified database.
131+
132+
Parameters:
133+
connection: The database connection object.
134+
database (str): The name of the database.
135+
136+
Returns:
137+
DataFrame: A pandas DataFrame containing the table names from the information schema.
138+
"""
139+
self.validate_connection(connection)
140+
df_tables = self.execute_sql(connection, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY.format(database))
141+
return df_tables
142+
143+
def get_all_ddls(self, connection, database: str) -> pd.DataFrame:
144+
"""
145+
Get all DDLs from the specified database using the provided connection object.
146+
147+
Parameters:
148+
connection (any): The connection object.
149+
database (str): The name of the database.
150+
151+
Returns:
152+
pd.DataFrame: A pandas DataFrame containing the DDLs for each table in the specified database.
153+
"""
154+
self.validate_connection(connection)
155+
df_tables = self.get_table_names(connection, database)
156+
df_ddl = pd.DataFrame(columns=['Table', 'DDL'])
157+
for index, row in df_tables.iterrows():
158+
# Handle both uppercase and lowercase column names
159+
table_name = row.get('TABLE_NAME') or row.get('table_name')
160+
if table_name:
161+
ddl_df = self.get_ddl(connection, table_name)
162+
df_ddl = df_ddl._append({'Table': table_name, 'DDL': ddl_df}, ignore_index=True)
163+
return df_ddl
164+
165+
def get_ddl(self, connection: any, table_name: str, **kwargs) -> str:
166+
"""
167+
A method to get the DDL for the table.
168+
169+
Parameters:
170+
connection (any): The connection object.
171+
table_name (str): The name of the table.
172+
**kwargs: Additional keyword arguments.
173+
174+
Returns:
175+
str: The DDL for the table.
176+
"""
177+
ddl_df = self.execute_sql(connection, MARIADB_SHOW_CREATE_TABLE_QUERY.format(table_name))
178+
return ddl_df["Create Table"].iloc[0]
179+
180+
def get_dialect(self) -> str:
181+
"""
182+
A method to get the dialect of the database.
183+
184+
Returns:
185+
str: The dialect of the database.
186+
"""
187+
return 'mysql'

mindsql/llms/googlegenai.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@ def __init__(self, config=None):
2222
raise ValueError(GOOGLE_GEN_AI_APIKEY_ERROR)
2323
api_key = config.pop('api_key')
2424
genai.configure(api_key=api_key)
25-
self.model = genai.GenerativeModel('gemini-pro', **config)
25+
26+
# Get model name from config, default to gemini-1.5-flash
27+
model_name = config.pop('model', 'gemini-1.5-flash')
28+
# Store temperature for later use if provided
29+
self.default_temperature = config.pop('temperature', 0.1)
30+
self.model = genai.GenerativeModel(model_name, **config)
2631

2732
def system_message(self, message: str) -> any:
2833
"""
@@ -75,7 +80,7 @@ def invoke(self, prompt, **kwargs) -> str:
7580
if prompt is None or len(prompt) == 0:
7681
raise Exception("Prompt cannot be empty.")
7782

78-
temperature = kwargs.get("temperature", 0.1)
83+
temperature = kwargs.get("temperature", self.default_temperature)
7984
response = self.model.generate_content(prompt,
8085
generation_config=genai.GenerationConfig(temperature=temperature))
8186
return response.text

mindsql/vectorstores/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .chromadb import ChromaDB
33
from .faiss_db import Faiss
44
from .qdrant import Qdrant
5+
from .mariadb_vector import MariaDBVectorStore

mindsql/vectorstores/chromadb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from . import IVectorstore
1212

13-
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="WhereIsAI/UAE-Large-V1")
13+
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
1414

1515

1616
class ChromaDB(IVectorstore):

0 commit comments

Comments
 (0)