Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mindsql/_helper/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def log_and_return(extracted_sql: str) -> str:
log.info(LOG_AND_RETURN_CONSTANT.format(llm_response, extracted_sql))
return extracted_sql

if "SQLQuery:" in llm_response:
sql_part = llm_response.split("SQLQuery:", 1)[1].strip()
if "\n\n" in sql_part:
sql_part = sql_part.split("\n\n")[0].strip()
return log_and_return(sql_part)

sql_match = re.search(r"```(sql)?\n(.+?)```", llm_response, re.DOTALL)
if sql_match:
return log_and_return(sql_match.group(2).replace("`", ""))
Expand Down
3 changes: 3 additions & 0 deletions mindsql/_utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
MYSQL_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
MYSQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
MYSQL_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
MARIADB_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
MARIADB_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
MARIADB_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
POSTGRESQL_SHOW_DATABASE_QUERY = "SELECT datname as DATABASE_NAME FROM pg_database WHERE datistemplate = false;"
POSTGRESQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{db}';"
ERROR_DOWNLOADING_SQLITE_DB_CONSTANT = "Error downloading sqlite db: {}"
Expand Down
9 changes: 9 additions & 0 deletions mindsql/core/mindsql_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ def __get_ddl_statements(self, connection: any, tables: list[str], question: str
Returns:
list[str]: The list of DDL statements.
"""
vector_ddls = []
try:
vector_ddls = self.vectorstore.retrieve_relevant_ddl(question, **kwargs)
except Exception as e:
log.info(f"Vector store retrieval failed: {e}")

if vector_ddls and len(vector_ddls) > 0:
return vector_ddls

if tables and connection:
ddl_statements = []
for table_name in tables:
Expand Down
1 change: 1 addition & 0 deletions mindsql/databases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .idatabase import IDatabase
from .mariadb import MariaDB
from .mysql import MySql
from .postgres import Postgres
from .sqlite import Sqlite
Expand Down
99 changes: 99 additions & 0 deletions mindsql/databases/mariadb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import List
from urllib.parse import urlparse

import mariadb
import pandas as pd

from .._utils import logger
from .._utils.constants import SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT, ERROR_CONNECTING_TO_DB_CONSTANT, \
INVALID_DB_CONNECTION_OBJECT, ERROR_WHILE_RUNNING_QUERY, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY, \
MARIADB_SHOW_DATABASE_QUERY, MARIADB_SHOW_CREATE_TABLE_QUERY, CONNECTION_ESTABLISH_ERROR_CONSTANT
from . import IDatabase

log = logger.init_loggers("MariaDB")


class MariaDB(IDatabase):
    """IDatabase implementation backed by a MariaDB server via the `mariadb` connector."""

    def create_connection(self, url: str, **kwargs) -> any:
        """Create a MariaDB connection from a connection URL.

        Args:
            url (str): Connection URL, e.g. ``mariadb://user:pass@host:3306/dbname``.
            **kwargs: Extra connector options; ``port`` is used as a fallback when
                the URL carries no port, all other keys are passed through.

        Returns:
            A live connection object, or None if the connection attempt failed.
        """
        parsed = urlparse(url)
        try:
            connection_params = {
                'host': parsed.hostname,
                'port': parsed.port or int(kwargs.get('port', 3306)),
                'user': parsed.username,
                'password': parsed.password,
                'database': parsed.path.lstrip('/') if parsed.path else None,
                'autocommit': True,
            }

            # Drop unset values, then let caller kwargs override everything except
            # 'port', which was already folded into the parsed parameters above.
            connection_params = {k: v for k, v in connection_params.items() if v is not None}
            connection_params.update({k: v for k, v in kwargs.items() if k not in ['port']})

            conn = mariadb.connect(**connection_params)
            log.info(SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT.format("MariaDB"))
            return conn

        except mariadb.Error as e:
            log.info(ERROR_CONNECTING_TO_DB_CONSTANT.format("MariaDB", str(e)))
            return None

    def validate_connection(self, connection: any) -> None:
        """Raise ValueError when *connection* is missing or not cursor-capable."""
        if connection is None:
            raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
        if not hasattr(connection, 'cursor'):
            raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("MariaDB"))

    def execute_sql(self, connection, sql: str) -> pd.DataFrame:
        """Execute *sql* and return the result set as a DataFrame.

        DML/DDL statements are committed and yield an empty DataFrame.
        On a mariadb error the error is logged and an empty DataFrame returned.
        """
        try:
            self.validate_connection(connection)
            cursor = connection.cursor()
            try:
                cursor.execute(sql)

                if sql.strip().upper().startswith(('CREATE', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER')):
                    connection.commit()
                    return pd.DataFrame()

                results = cursor.fetchall()
                if cursor.description:
                    column_names = [col[0] for col in cursor.description]
                    return pd.DataFrame(results, columns=column_names)
                return pd.DataFrame()
            finally:
                # Close even when execute/fetch raises; the cursor leaked on error before.
                cursor.close()
        except mariadb.Error as e:
            log.info(ERROR_WHILE_RUNNING_QUERY.format(e))
            return pd.DataFrame()

    def get_databases(self, connection) -> List[str]:
        """Return the database names visible to the connection (empty list on error)."""
        try:
            self.validate_connection(connection)
            df_databases = self.execute_sql(connection=connection, sql=MARIADB_SHOW_DATABASE_QUERY)
        except Exception as e:
            log.info(e)
            return []
        return df_databases["Database"].unique().tolist()

    def get_table_names(self, connection, database: str) -> pd.DataFrame:
        """Return a DataFrame of table names in *database* (information_schema query)."""
        self.validate_connection(connection)
        df_tables = self.execute_sql(connection, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY.format(database))
        return df_tables

    def get_all_ddls(self, connection, database: str) -> pd.DataFrame:
        """Return a DataFrame with columns ['Table', 'DDL'] for every table in *database*."""
        self.validate_connection(connection)
        df_tables = self.get_table_names(connection, database)
        rows = []
        for _, row in df_tables.iterrows():
            # Column case depends on server/connector version; accept both spellings.
            table_name = row.get('TABLE_NAME') or row.get('table_name')
            if table_name:
                rows.append({'Table': table_name, 'DDL': self.get_ddl(connection, table_name)})
        # DataFrame.append was removed in pandas 2.0 and _append is private API;
        # accumulate rows in a list and build the frame once instead.
        return pd.DataFrame(rows, columns=['Table', 'DDL'])

    def get_ddl(self, connection: any, table_name: str, **kwargs) -> str:
        """Return the CREATE TABLE statement for *table_name*.

        Raises:
            ValueError: If the DDL could not be retrieved (previously this
                surfaced as an opaque KeyError on an empty DataFrame).
        """
        ddl_df = self.execute_sql(connection, MARIADB_SHOW_CREATE_TABLE_QUERY.format(table_name))
        if ddl_df.empty or "Create Table" not in ddl_df.columns:
            raise ValueError(f"Could not retrieve DDL for table '{table_name}'")
        return ddl_df["Create Table"].iloc[0]

    def get_dialect(self) -> str:
        """MariaDB speaks the MySQL SQL dialect as far as prompt generation is concerned."""
        return 'mysql'
1 change: 1 addition & 0 deletions mindsql/vectorstores/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .chromadb import ChromaDB
from .faiss_db import Faiss
from .qdrant import Qdrant
from .mariadb_vector import MariaDBVectorStore
211 changes: 211 additions & 0 deletions mindsql/vectorstores/mariadb_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import json
import uuid
from typing import List
import mariadb
import pandas as pd
from sentence_transformers import SentenceTransformer
from mindsql.vectorstores import IVectorstore


class MariaDBVectorStore(IVectorstore):
    """Vector store backed by MariaDB using its native VECTOR column type.

    Two tables are used: ``<collection>`` holds DDL/documentation documents and
    ``<collection>_sql_pairs`` holds question/SQL example pairs. Embeddings are
    produced locally with sentence-transformers (all-MiniLM-L6-v2, 384 dims).
    """

    def __init__(self, config=None):
        """Initialise the store and ensure the backing tables exist.

        Args:
            config (dict): host, port, user, password, optional database,
                optional collection_name (default 'mindsql_vectors').

        Raises:
            ValueError: If config is missing or collection_name is not a plain
                identifier.
            RuntimeError: If the backing tables cannot be created.
        """
        if config is None:
            raise ValueError("MariaDB configuration is required")

        self.collection_name = config.get('collection_name', 'mindsql_vectors')
        # The collection name is interpolated into SQL as an identifier
        # (identifiers cannot be bound as parameters), so reject anything that
        # is not a plain identifier to rule out SQL injection via configuration.
        if not str(self.collection_name).isidentifier():
            raise ValueError(f"Invalid collection name: {self.collection_name!r}")

        self.connection_params = {
            'host': config.get('host', 'localhost'),
            'port': config.get('port', 3306),
            'user': config.get('user'),
            'password': config.get('password'),
        }

        if 'database' in config and config['database']:
            self.connection_params['database'] = config['database']

        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.dimension = 384  # output dimension of all-MiniLM-L6-v2
        self._init_database()

    def _connect(self):
        """Open a fresh connection using the configured parameters."""
        return mariadb.connect(**self.connection_params)

    def _init_database(self):
        """Create the backing tables if they do not already exist.

        Uses CREATE TABLE IF NOT EXISTS so that constructing a new store does
        NOT wipe previously indexed data (an earlier revision DROP-ped and
        recreated both tables on every __init__, silently losing everything).
        """
        try:
            conn = self._connect()
            try:
                cursor = conn.cursor()

                cursor.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.collection_name} (
                        id VARCHAR(36) PRIMARY KEY,
                        document TEXT NOT NULL,
                        embedding VECTOR({self.dimension}) NOT NULL,
                        metadata JSON,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        INDEX idx_created_at (created_at),
                        FULLTEXT(document)
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """)

                cursor.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.collection_name}_sql_pairs (
                        id VARCHAR(36) PRIMARY KEY,
                        question TEXT NOT NULL,
                        sql_query TEXT NOT NULL,
                        embedding VECTOR({self.dimension}) NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        FULLTEXT(question, sql_query)
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """)

                conn.commit()
                cursor.close()
            finally:
                conn.close()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize MariaDB vector store: {e}") from e

    def _format_vector_for_insertion(self, embedding_array):
        """Render an embedding as the text form accepted by VEC_FromText()."""
        if len(embedding_array) != self.dimension:
            raise ValueError(f"Expected {self.dimension} dimensions, got {len(embedding_array)}")
        return '[' + ','.join(f'{float(x)}' for x in embedding_array) + ']'

    def _insert_document(self, document: str, doc_type: str):
        """Embed *document* and insert it into the main collection table."""
        embedding = self.embedding_model.encode(document).tolist()
        vector_json = self._format_vector_for_insertion(embedding)

        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"""
                INSERT INTO {self.collection_name}
                (id, document, embedding, metadata)
                VALUES (?, ?, VEC_FromText(?), ?)
            """, (str(uuid.uuid4()), document, vector_json, json.dumps({"type": doc_type})))
            conn.commit()
            cursor.close()
        finally:
            # Always release the connection, even when execute raises
            # (previous revision leaked it on error).
            conn.close()

    def add_ddl(self, ddl: str):
        """Embed and store a DDL statement."""
        self._insert_document(ddl, "ddl")

    def add_question_sql(self, question: str, sql: str):
        """Embed the question and store the question/SQL pair."""
        embedding = self.embedding_model.encode(question).tolist()
        vector_json = self._format_vector_for_insertion(embedding)

        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"""
                INSERT INTO {self.collection_name}_sql_pairs
                (id, question, sql_query, embedding)
                VALUES (?, ?, ?, VEC_FromText(?))
            """, (str(uuid.uuid4()), question, sql, vector_json))
            conn.commit()
            cursor.close()
        finally:
            conn.close()

    def get_similar_question_sql(self, question: str, n_results: int = 5):
        """Return up to *n_results* stored pairs ranked by full-text relevance.

        NOTE(review): ranking uses MariaDB full-text MATCH ... AGAINST, not the
        stored vector embeddings, so 'similarity' here is a text score —
        confirm this is the intended retrieval strategy.
        """
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"""
                SELECT question, sql_query,
                       MATCH(question, sql_query) AGAINST (? IN NATURAL LANGUAGE MODE) as text_score
                FROM {self.collection_name}_sql_pairs
                WHERE MATCH(question, sql_query) AGAINST (? IN NATURAL LANGUAGE MODE)
                ORDER BY text_score DESC LIMIT ?
            """, (question, question, n_results))
            results = cursor.fetchall()
            cursor.close()
        finally:
            conn.close()

        return [{'question': r[0], 'sql': r[1], 'similarity': r[2], 'text_score': r[2]} for r in results]

    def _fetch_documents_by_type(self, doc_type: str, limit: int) -> list:
        """Return the most recent documents tagged with *doc_type* (recency order,
        not semantic similarity — the question text is not used)."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"""
                SELECT document FROM {self.collection_name}
                WHERE JSON_EXTRACT(metadata, '$.type') = ?
                ORDER BY created_at DESC LIMIT ?
            """, (doc_type, limit))
            results = cursor.fetchall()
            cursor.close()
        finally:
            conn.close()
        return [row[0] for row in results]

    def retrieve_relevant_ddl(self, question: str, **kwargs) -> list:
        """Return recent DDL documents; *question* is currently unused."""
        return self._fetch_documents_by_type('ddl', kwargs.get('n_results', 5))

    def retrieve_relevant_documentation(self, question: str, **kwargs) -> list:
        """Return recent documentation entries; *question* is currently unused."""
        return self._fetch_documents_by_type('documentation', kwargs.get('n_results', 5))

    def retrieve_relevant_question_sql(self, question: str, **kwargs) -> list:
        """Delegate to full-text question/SQL retrieval (default 3 results)."""
        return self.get_similar_question_sql(question, kwargs.get('n_results', 3))

    def index_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair; returns a status message, never raises."""
        try:
            self.add_question_sql(question, sql)
            return "Successfully added question-SQL pair"
        except Exception as e:
            return f"Failed: {e}"

    def index_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns a status message, never raises."""
        try:
            self.add_ddl(ddl)
            return "Successfully added DDL"
        except Exception as e:
            return f"Failed: {e}"

    def index_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation snippet; returns a status message, never raises."""
        try:
            self._insert_document(documentation, "documentation")
            return "Successfully added documentation"
        except Exception as e:
            return f"Failed: {e}"

    def fetch_all_vectorstore_data(self, **kwargs) -> pd.DataFrame:
        """Return every stored item as a DataFrame with id/content/type/created_at."""
        conn = self._connect()
        try:
            main_df = pd.read_sql(f"SELECT id, document, created_at FROM {self.collection_name}", conn)
            sql_pairs_df = pd.read_sql(
                f"SELECT id, question, sql_query, created_at FROM {self.collection_name}_sql_pairs", conn)
        finally:
            conn.close()

        data = []
        for _, row in main_df.iterrows():
            data.append({'id': row['id'], 'content': row['document'], 'type': 'document',
                         'created_at': row['created_at']})
        for _, row in sql_pairs_df.iterrows():
            data.append({'id': row['id'], 'content': f"Q: {row['question']} | SQL: {row['sql_query']}",
                         'type': 'question_sql', 'created_at': row['created_at']})
        return pd.DataFrame(data)

    def delete_vectorstore_data(self, item_id: str, **kwargs) -> bool:
        """Delete *item_id* from both tables; True if any row was removed."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"DELETE FROM {self.collection_name} WHERE id = ?", (item_id,))
            main_deleted = cursor.rowcount
            cursor.execute(f"DELETE FROM {self.collection_name}_sql_pairs WHERE id = ?", (item_id,))
            pairs_deleted = cursor.rowcount
            conn.commit()
            cursor.close()
        finally:
            conn.close()
        return (main_deleted + pairs_deleted) > 0

    def add_documents(self, documents: List[str]):
        """Embed and store each document as a DDL entry."""
        for doc in documents:
            self.add_ddl(doc)