From 2c2faaee533bd288a343112e5deaebbb8502db59 Mon Sep 17 00:00:00 2001 From: imash Date: Sat, 1 Nov 2025 11:03:29 +0530 Subject: [PATCH] Add MariaDB integration with native VECTOR support Implement MariaDB database connector and vector store with native VECTOR(384) data type support for efficient similarity search. Includes improvements to SQL extraction and DDL retrieval. --- mindsql/_helper/helper.py | 6 + mindsql/_utils/constants.py | 3 + mindsql/core/mindsql_core.py | 9 ++ mindsql/databases/__init__.py | 1 + mindsql/databases/mariadb.py | 99 ++++++++++++ mindsql/vectorstores/__init__.py | 1 + mindsql/vectorstores/mariadb_vector.py | 211 +++++++++++++++++++++++++ 7 files changed, 330 insertions(+) create mode 100644 mindsql/databases/mariadb.py create mode 100644 mindsql/vectorstores/mariadb_vector.py diff --git a/mindsql/_helper/helper.py b/mindsql/_helper/helper.py index e4639c2..a9e2f7e 100644 --- a/mindsql/_helper/helper.py +++ b/mindsql/_helper/helper.py @@ -62,6 +62,12 @@ def log_and_return(extracted_sql: str) -> str: log.info(LOG_AND_RETURN_CONSTANT.format(llm_response, extracted_sql)) return extracted_sql + if "SQLQuery:" in llm_response: + sql_part = llm_response.split("SQLQuery:", 1)[1].strip() + if "\n\n" in sql_part: + sql_part = sql_part.split("\n\n")[0].strip() + return log_and_return(sql_part) + sql_match = re.search(r"```(sql)?\n(.+?)```", llm_response, re.DOTALL) if sql_match: return log_and_return(sql_match.group(2).replace("`", "")) diff --git a/mindsql/_utils/constants.py b/mindsql/_utils/constants.py index 66a7408..c9320e1 100644 --- a/mindsql/_utils/constants.py +++ b/mindsql/_utils/constants.py @@ -18,6 +18,9 @@ MYSQL_SHOW_DATABASE_QUERY = "SHOW DATABASES;" MYSQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';" MYSQL_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;" +MARIADB_SHOW_DATABASE_QUERY = "SHOW DATABASES;" +MARIADB_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM 
information_schema.tables WHERE table_schema = '{}';" +MARIADB_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;" POSTGRESQL_SHOW_DATABASE_QUERY = "SELECT datname as DATABASE_NAME FROM pg_database WHERE datistemplate = false;" POSTGRESQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{db}';" ERROR_DOWNLOADING_SQLITE_DB_CONSTANT = "Error downloading sqlite db: {}" diff --git a/mindsql/core/mindsql_core.py b/mindsql/core/mindsql_core.py index 8ba32e1..0dbe5e9 100644 --- a/mindsql/core/mindsql_core.py +++ b/mindsql/core/mindsql_core.py @@ -176,6 +176,15 @@ def __get_ddl_statements(self, connection: any, tables: list[str], question: str Returns: list[str]: The list of DDL statements. """ + vector_ddls = [] + try: + vector_ddls = self.vectorstore.retrieve_relevant_ddl(question, **kwargs) + except Exception as e: + log.info(f"Vector store retrieval failed: {e}") + + if vector_ddls and len(vector_ddls) > 0: + return vector_ddls + if tables and connection: ddl_statements = [] for table_name in tables: diff --git a/mindsql/databases/__init__.py b/mindsql/databases/__init__.py index 0034303..1a44573 100644 --- a/mindsql/databases/__init__.py +++ b/mindsql/databases/__init__.py @@ -1,4 +1,5 @@ from .idatabase import IDatabase +from .mariadb import MariaDB from .mysql import MySql from .postgres import Postgres from .sqlite import Sqlite diff --git a/mindsql/databases/mariadb.py b/mindsql/databases/mariadb.py new file mode 100644 index 0000000..5e52529 --- /dev/null +++ b/mindsql/databases/mariadb.py @@ -0,0 +1,99 @@ +from typing import List +from urllib.parse import urlparse + +import mariadb +import pandas as pd + +from .._utils import logger +from .._utils.constants import SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT, ERROR_CONNECTING_TO_DB_CONSTANT, \ + INVALID_DB_CONNECTION_OBJECT, ERROR_WHILE_RUNNING_QUERY, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY, \ + MARIADB_SHOW_DATABASE_QUERY, 
log = logger.init_loggers("MariaDB")


class MariaDB(IDatabase):
    """MariaDB implementation of the IDatabase interface.

    Connections are created from a URL of the form
    ``mariadb://user:password@host:port/database`` and all query results
    are surfaced as pandas DataFrames.
    """

    def create_connection(self, url: str, **kwargs) -> any:
        """Open a MariaDB connection parsed from *url*.

        Extra keyword arguments are forwarded to ``mariadb.connect``
        (except ``port``, which is only a fallback when the URL carries
        no port). Returns the connection, or None when connecting fails.
        """
        parsed = urlparse(url)
        try:
            connection_params = {
                'host': parsed.hostname,
                'port': parsed.port or int(kwargs.get('port', 3306)),
                'user': parsed.username,
                'password': parsed.password,
                'database': parsed.path.lstrip('/') if parsed.path else None,
                'autocommit': True,
            }
            # Drop unset values, then layer caller overrides on top
            # ('port' was already consumed above).
            connection_params = {k: v for k, v in connection_params.items() if v is not None}
            connection_params.update({k: v for k, v in kwargs.items() if k != 'port'})

            conn = mariadb.connect(**connection_params)
            log.info(SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT.format("MariaDB"))
            return conn
        except mariadb.Error as e:
            log.info(ERROR_CONNECTING_TO_DB_CONSTANT.format("MariaDB", str(e)))
            return None

    def validate_connection(self, connection: any) -> None:
        """Raise ValueError when *connection* is None or not cursor-capable."""
        if connection is None:
            raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
        if not hasattr(connection, 'cursor'):
            raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("MariaDB"))

    def execute_sql(self, connection, sql: str) -> pd.DataFrame:
        """Run *sql* on *connection* and return the result as a DataFrame.

        DDL/DML statements are committed and yield an empty frame; query
        errors are logged and also yield an empty frame (best-effort
        contract shared with the other connectors in this package).
        """
        try:
            self.validate_connection(connection)
            cursor = connection.cursor()
            try:
                cursor.execute(sql)
                if sql.strip().upper().startswith(('CREATE', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER')):
                    # Redundant under autocommit=True, but kept for
                    # caller-supplied connections opened without it.
                    connection.commit()
                    return pd.DataFrame()
                results = cursor.fetchall()
                if cursor.description:
                    column_names = [col[0] for col in cursor.description]
                    return pd.DataFrame(results, columns=column_names)
                return pd.DataFrame()
            finally:
                # FIX: close the cursor even when execute/fetch raises
                # (the original leaked it on error).
                cursor.close()
        except mariadb.Error as e:
            log.info(ERROR_WHILE_RUNNING_QUERY.format(e))
            return pd.DataFrame()

    def get_databases(self, connection) -> List[str]:
        """Return the database names visible on *connection* ([] on error)."""
        try:
            self.validate_connection(connection)
            df_databases = self.execute_sql(connection=connection, sql=MARIADB_SHOW_DATABASE_QUERY)
        except Exception as e:
            log.info(e)
            return []
        return df_databases["Database"].unique().tolist()

    def get_table_names(self, connection, database: str) -> pd.DataFrame:
        """Return a one-column DataFrame of table names in *database*."""
        self.validate_connection(connection)
        return self.execute_sql(connection, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY.format(database))

    def get_all_ddls(self, connection, database: str) -> pd.DataFrame:
        """Return a DataFrame with 'Table' and 'DDL' columns for *database*.

        FIX: the original used ``DataFrame._append``, a private pandas API
        that was removed in pandas 2.x; rows are now collected in a list
        and the frame is constructed once.
        """
        self.validate_connection(connection)
        df_tables = self.get_table_names(connection, database)
        rows = []
        for _, row in df_tables.iterrows():
            # information_schema may report the column upper- or lower-case
            # depending on server version/configuration.
            table_name = row.get('TABLE_NAME') or row.get('table_name')
            if table_name:
                rows.append({'Table': table_name, 'DDL': self.get_ddl(connection, table_name)})
        return pd.DataFrame(rows, columns=['Table', 'DDL'])

    def get_ddl(self, connection: any, table_name: str, **kwargs) -> str:
        """Return the CREATE TABLE statement for *table_name*."""
        ddl_df = self.execute_sql(connection, MARIADB_SHOW_CREATE_TABLE_QUERY.format(table_name))
        return ddl_df["Create Table"].iloc[0]

    def get_dialect(self) -> str:
        # MariaDB speaks the MySQL SQL dialect, so downstream prompt
        # templates can reuse the existing mysql dialect name.
        return 'mysql'
class MariaDBVectorStore(IVectorstore):
    """Vector store backed by MariaDB's native VECTOR(384) column type.

    Two tables are used: ``<collection_name>`` for DDL/documentation chunks
    and ``<collection_name>_sql_pairs`` for question/SQL examples.
    Embeddings are produced by the all-MiniLM-L6-v2 sentence-transformer
    (384 dimensions). Requires MariaDB 11.7+ (VECTOR type, VEC_* functions).
    """

    def __init__(self, config=None):
        """Initialise the store from a config dict.

        Expected keys: host, port, user, password, and optionally
        database and collection_name. Raises ValueError when *config*
        is None; RuntimeError when table initialisation fails.
        """
        if config is None:
            raise ValueError("MariaDB configuration is required")

        self.collection_name = config.get('collection_name', 'mindsql_vectors')
        self.connection_params = {
            'host': config.get('host', 'localhost'),
            'port': config.get('port', 3306),
            'user': config.get('user'),
            'password': config.get('password'),
        }
        if 'database' in config and config['database']:
            self.connection_params['database'] = config['database']

        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.dimension = 384  # output width of all-MiniLM-L6-v2
        self._init_database()

    def _connect(self):
        """Open a fresh connection using the configured parameters."""
        return mariadb.connect(**self.connection_params)

    def _embed_as_vec_text(self, text: str) -> str:
        """Encode *text* and return it in the textual form VEC_FromText expects."""
        return self._format_vector_for_insertion(self.embedding_model.encode(text).tolist())

    def _init_database(self):
        """Create the backing tables when they do not already exist.

        FIX: the original issued DROP TABLE IF EXISTS on every
        instantiation, silently destroying all previously indexed data.
        CREATE TABLE IF NOT EXISTS preserves rows across restarts.
        """
        try:
            conn = self._connect()
            try:
                cursor = conn.cursor()
                cursor.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.collection_name} (
                        id VARCHAR(36) PRIMARY KEY,
                        document TEXT NOT NULL,
                        embedding VECTOR({self.dimension}) NOT NULL,
                        metadata JSON,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        INDEX idx_created_at (created_at),
                        FULLTEXT(document)
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """)
                cursor.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.collection_name}_sql_pairs (
                        id VARCHAR(36) PRIMARY KEY,
                        question TEXT NOT NULL,
                        sql_query TEXT NOT NULL,
                        embedding VECTOR({self.dimension}) NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        FULLTEXT(question, sql_query)
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """)
                conn.commit()
                cursor.close()
            finally:
                conn.close()
        except Exception as e:
            # Chain the cause so the original driver error stays visible.
            raise RuntimeError(f"Failed to initialize MariaDB vector store: {e}") from e

    def _format_vector_for_insertion(self, embedding_array):
        """Render an embedding as '[x1,x2,...]' for VEC_FromText.

        Raises ValueError when the embedding width does not match the
        configured dimension.
        """
        if len(embedding_array) != self.dimension:
            raise ValueError(f"Expected {self.dimension} dimensions, got {len(embedding_array)}")
        return '[' + ','.join(f'{float(x)}' for x in embedding_array) + ']'

    def _insert_document(self, document: str, doc_type: str) -> None:
        """Embed and insert *document*, tagged with metadata type *doc_type*.

        Shared by add_ddl and index_documentation (the original duplicated
        this insert verbatim).
        """
        vector_text = self._embed_as_vec_text(document)
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"""
                INSERT INTO {self.collection_name}
                (id, document, embedding, metadata)
                VALUES (?, ?, VEC_FromText(?), ?)
            """, (str(uuid.uuid4()), document, vector_text, json.dumps({"type": doc_type})))
            conn.commit()
            cursor.close()
        finally:
            # FIX: close the connection even when the insert fails.
            conn.close()

    def add_ddl(self, ddl: str):
        """Embed *ddl* and store it in the document table."""
        self._insert_document(ddl, "ddl")

    def add_question_sql(self, question: str, sql: str):
        """Embed *question* and store the question/SQL pair."""
        vector_text = self._embed_as_vec_text(question)
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"""
                INSERT INTO {self.collection_name}_sql_pairs
                (id, question, sql_query, embedding)
                VALUES (?, ?, ?, VEC_FromText(?))
            """, (str(uuid.uuid4()), question, sql, vector_text))
            conn.commit()
            cursor.close()
        finally:
            conn.close()

    def get_similar_question_sql(self, question: str, n_results: int = 5):
        """Return up to *n_results* stored pairs ranked by fulltext relevance.

        NOTE(review): this ranks by FULLTEXT score, not vector distance;
        the 'similarity' key mirrors the text score for backward
        compatibility with existing callers.
        """
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"""
                SELECT question, sql_query,
                    MATCH(question, sql_query) AGAINST (? IN NATURAL LANGUAGE MODE) as text_score
                FROM {self.collection_name}_sql_pairs
                WHERE MATCH(question, sql_query) AGAINST (? IN NATURAL LANGUAGE MODE)
                ORDER BY text_score DESC LIMIT ?
            """, (question, question, n_results))
            results = cursor.fetchall()
            cursor.close()
        finally:
            conn.close()
        return [{'question': r[0], 'sql': r[1], 'similarity': r[2], 'text_score': r[2]} for r in results]

    def _retrieve_documents(self, question: str, doc_type: str, n_results: int) -> list:
        """Return up to *n_results* documents of *doc_type* ranked by cosine
        distance between their stored embedding and the question embedding."""
        vector_text = self._embed_as_vec_text(question)
        conn = self._connect()
        try:
            cursor = conn.cursor()
            # JSON_UNQUOTE makes the comparison robust whether the server
            # returns the extracted scalar quoted or not.
            cursor.execute(f"""
                SELECT document FROM {self.collection_name}
                WHERE JSON_UNQUOTE(JSON_EXTRACT(metadata, '$.type')) = ?
                ORDER BY VEC_DISTANCE_COSINE(embedding, VEC_FromText(?)) ASC
                LIMIT ?
            """, (doc_type, vector_text, n_results))
            results = cursor.fetchall()
            cursor.close()
        finally:
            conn.close()
        return [row[0] for row in results]

    def retrieve_relevant_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL documents most relevant to *question*.

        FIX: the original ignored *question* and sorted by created_at,
        defeating the purpose of storing embeddings; results are now
        ranked by VEC_DISTANCE_COSINE over the VECTOR column.
        """
        return self._retrieve_documents(question, 'ddl', kwargs.get('n_results', 5))

    def retrieve_relevant_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation chunks most relevant to *question*
        (same similarity ranking fix as retrieve_relevant_ddl)."""
        return self._retrieve_documents(question, 'documentation', kwargs.get('n_results', 5))

    def retrieve_relevant_question_sql(self, question: str, **kwargs) -> list:
        """Return question/SQL pairs relevant to *question* (default top 3)."""
        return self.get_similar_question_sql(question, kwargs.get('n_results', 3))

    def index_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair; returns a status message string."""
        try:
            self.add_question_sql(question, sql)
            return "Successfully added question-SQL pair"
        except Exception as e:
            return f"Failed: {e}"

    def index_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns a status message string."""
        try:
            self.add_ddl(ddl)
            return "Successfully added DDL"
        except Exception as e:
            return f"Failed: {e}"

    def index_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation chunk; returns a status message string."""
        try:
            self._insert_document(documentation, "documentation")
            return "Successfully added documentation"
        except Exception as e:
            return f"Failed: {e}"

    def fetch_all_vectorstore_data(self, **kwargs) -> pd.DataFrame:
        """Return every stored item as a DataFrame with
        id / content / type / created_at columns."""
        conn = self._connect()
        try:
            main_df = pd.read_sql(f"SELECT id, document, created_at FROM {self.collection_name}", conn)
            sql_pairs_df = pd.read_sql(
                f"SELECT id, question, sql_query, created_at FROM {self.collection_name}_sql_pairs", conn)
        finally:
            conn.close()

        data = []
        for _, row in main_df.iterrows():
            data.append({'id': row['id'], 'content': row['document'], 'type': 'document',
                         'created_at': row['created_at']})
        for _, row in sql_pairs_df.iterrows():
            data.append({'id': row['id'], 'content': f"Q: {row['question']} | SQL: {row['sql_query']}",
                         'type': 'question_sql', 'created_at': row['created_at']})
        return pd.DataFrame(data)

    def delete_vectorstore_data(self, item_id: str, **kwargs) -> bool:
        """Delete *item_id* from both tables; True when anything was removed."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(f"DELETE FROM {self.collection_name} WHERE id = ?", (item_id,))
            main_deleted = cursor.rowcount
            cursor.execute(f"DELETE FROM {self.collection_name}_sql_pairs WHERE id = ?", (item_id,))
            pairs_deleted = cursor.rowcount
            conn.commit()
            cursor.close()
        finally:
            conn.close()
        return (main_deleted + pairs_deleted) > 0

    def add_documents(self, documents: List[str]):
        """Index each string in *documents* as a DDL document."""
        for doc in documents:
            self.add_ddl(doc)