diff --git a/mindsql/_helper/helper.py b/mindsql/_helper/helper.py
index e4639c2..a9e2f7e 100644
--- a/mindsql/_helper/helper.py
+++ b/mindsql/_helper/helper.py
@@ -62,6 +62,12 @@ def log_and_return(extracted_sql: str) -> str:
         log.info(LOG_AND_RETURN_CONSTANT.format(llm_response, extracted_sql))
         return extracted_sql
 
+    if "SQLQuery:" in llm_response:
+        sql_part = llm_response.split("SQLQuery:", 1)[1].strip()
+        if "\n\n" in sql_part:
+            sql_part = sql_part.split("\n\n")[0].strip()
+        return log_and_return(sql_part)
+
     sql_match = re.search(r"```(sql)?\n(.+?)```", llm_response, re.DOTALL)
     if sql_match:
         return log_and_return(sql_match.group(2).replace("`", ""))
diff --git a/mindsql/_utils/constants.py b/mindsql/_utils/constants.py
index 66a7408..c9320e1 100644
--- a/mindsql/_utils/constants.py
+++ b/mindsql/_utils/constants.py
@@ -18,6 +18,9 @@
 MYSQL_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
 MYSQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
 MYSQL_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
+MARIADB_SHOW_DATABASE_QUERY = "SHOW DATABASES;"
+MARIADB_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}';"
+MARIADB_SHOW_CREATE_TABLE_QUERY = "SHOW CREATE TABLE `{}`;"
 POSTGRESQL_SHOW_DATABASE_QUERY = "SELECT datname as DATABASE_NAME FROM pg_database WHERE datistemplate = false;"
 POSTGRESQL_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{db}';"
 ERROR_DOWNLOADING_SQLITE_DB_CONSTANT = "Error downloading sqlite db: {}"
diff --git a/mindsql/core/mindsql_core.py b/mindsql/core/mindsql_core.py
index 8ba32e1..0dbe5e9 100644
--- a/mindsql/core/mindsql_core.py
+++ b/mindsql/core/mindsql_core.py
@@ -176,6 +176,15 @@ def __get_ddl_statements(self, connection: any, tables: list[str], question: str
         Returns:
             list[str]: The list of DDL statements.
         """
+        vector_ddls = []
+        try:
+            vector_ddls = self.vectorstore.retrieve_relevant_ddl(question, **kwargs)
+        except Exception as e:
+            log.info(f"Vector store retrieval failed: {e}")
+
+        if vector_ddls:
+            return vector_ddls
+
         if tables and connection:
             ddl_statements = []
             for table_name in tables:
diff --git a/mindsql/databases/__init__.py b/mindsql/databases/__init__.py
index 0034303..1a44573 100644
--- a/mindsql/databases/__init__.py
+++ b/mindsql/databases/__init__.py
@@ -1,4 +1,5 @@
 from .idatabase import IDatabase
+from .mariadb import MariaDB
 from .mysql import MySql
 from .postgres import Postgres
 from .sqlite import Sqlite
diff --git a/mindsql/databases/mariadb.py b/mindsql/databases/mariadb.py
new file mode 100644
index 0000000..5e52529
--- /dev/null
+++ b/mindsql/databases/mariadb.py
@@ -0,0 +1,99 @@
+from typing import List
+from urllib.parse import urlparse
+
+import mariadb
+import pandas as pd
+
+from .._utils import logger
+from .._utils.constants import SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT, ERROR_CONNECTING_TO_DB_CONSTANT, \
+    INVALID_DB_CONNECTION_OBJECT, ERROR_WHILE_RUNNING_QUERY, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY, \
+    MARIADB_SHOW_DATABASE_QUERY, MARIADB_SHOW_CREATE_TABLE_QUERY, CONNECTION_ESTABLISH_ERROR_CONSTANT
+from . import IDatabase
+
+log = logger.init_loggers("MariaDB")
+
+
+class MariaDB(IDatabase):
+    def create_connection(self, url: str, **kwargs) -> any:
+        url = urlparse(url)
+        try:
+            connection_params = {
+                'host': url.hostname,
+                'port': url.port or int(kwargs.get('port', 3306)),
+                'user': url.username,
+                'password': url.password,
+                'database': url.path.lstrip('/') if url.path else None,
+                'autocommit': True,
+            }
+
+            connection_params = {k: v for k, v in connection_params.items() if v is not None}
+            connection_params.update({k: v for k, v in kwargs.items() if k not in ['port']})
+
+            conn = mariadb.connect(**connection_params)
+            log.info(SUCCESSFULLY_CONNECTED_TO_DB_CONSTANT.format("MariaDB"))
+            return conn
+
+        except mariadb.Error as e:
+            log.info(ERROR_CONNECTING_TO_DB_CONSTANT.format("MariaDB", str(e)))
+            return None
+
+    def validate_connection(self, connection: any) -> None:
+        if connection is None:
+            raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
+        if not hasattr(connection, 'cursor'):
+            raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("MariaDB"))
+
+    def execute_sql(self, connection, sql: str) -> pd.DataFrame:
+        try:
+            self.validate_connection(connection)
+            cursor = connection.cursor()
+            cursor.execute(sql)
+
+            if sql.strip().upper().startswith(('CREATE', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER')):
+                connection.commit()
+                cursor.close()
+                return pd.DataFrame()
+
+            results = cursor.fetchall()
+            if cursor.description:
+                column_names = [i[0] for i in cursor.description]
+                df = pd.DataFrame(results, columns=column_names)
+            else:
+                df = pd.DataFrame()
+            cursor.close()
+            return df
+        except mariadb.Error as e:
+            log.info(ERROR_WHILE_RUNNING_QUERY.format(e))
+            return pd.DataFrame()
+
+    def get_databases(self, connection) -> List[str]:
+        try:
+            self.validate_connection(connection)
+            df_databases = self.execute_sql(connection=connection, sql=MARIADB_SHOW_DATABASE_QUERY)
+        except Exception as e:
+            log.info(e)
+            return []
+        return df_databases["Database"].unique().tolist()
+
+    def get_table_names(self, connection, database: str) -> pd.DataFrame:
+        self.validate_connection(connection)
+        df_tables = self.execute_sql(connection, MARIADB_DB_TABLES_INFO_SCHEMA_QUERY.format(database))
+        return df_tables
+
+    def get_all_ddls(self, connection, database: str) -> pd.DataFrame:
+        self.validate_connection(connection)
+        df_tables = self.get_table_names(connection, database)
+        df_ddl = pd.DataFrame(columns=['Table', 'DDL'])
+        for _, row in df_tables.iterrows():
+            table_name = row.get('TABLE_NAME') or row.get('table_name')
+            if table_name:
+                ddl = self.get_ddl(connection, table_name)
+                df_ddl = pd.concat([df_ddl, pd.DataFrame([{'Table': table_name, 'DDL': ddl}])], ignore_index=True)
+        return df_ddl
+
+    def get_ddl(self, connection: any, table_name: str, **kwargs) -> str:
+        ddl_df = self.execute_sql(connection, MARIADB_SHOW_CREATE_TABLE_QUERY.format(table_name))
+        return ddl_df["Create Table"].iloc[0]
+
+    def get_dialect(self) -> str:
+        return 'mysql'
diff --git a/mindsql/vectorstores/__init__.py b/mindsql/vectorstores/__init__.py
index ad17496..0fe3f90 100644
--- a/mindsql/vectorstores/__init__.py
+++ b/mindsql/vectorstores/__init__.py
@@ -2,3 +2,4 @@
 from .chromadb import ChromaDB
 from .faiss_db import Faiss
 from .qdrant import Qdrant
+from .mariadb_vector import MariaDBVectorStore
diff --git a/mindsql/vectorstores/mariadb_vector.py b/mindsql/vectorstores/mariadb_vector.py
new file mode 100644
index 0000000..86e0b64
--- /dev/null
+++ b/mindsql/vectorstores/mariadb_vector.py
@@ -0,0 +1,211 @@
+import json
+import uuid
+from typing import List
+import mariadb
+import pandas as pd
+from sentence_transformers import SentenceTransformer
+from . import IVectorstore
+
+
+class MariaDBVectorStore(IVectorstore):
+    def __init__(self, config=None):
+        if config is None:
+            raise ValueError("MariaDB configuration is required")
+
+        self.collection_name = config.get('collection_name', 'mindsql_vectors')
+        self.connection_params = {
+            'host': config.get('host', 'localhost'),
+            'port': config.get('port', 3306),
+            'user': config.get('user'),
+            'password': config.get('password'),
+        }
+
+        if 'database' in config and config['database']:
+            self.connection_params['database'] = config['database']
+
+        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+        self.dimension = 384
+        self._init_database()
+
+    def _init_database(self):
+        try:
+            conn = mariadb.connect(**self.connection_params)
+            cursor = conn.cursor()
+
+            # Use IF NOT EXISTS so re-instantiating the store does not
+            # wipe previously indexed vectors.
+
+            cursor.execute(f"""
+                CREATE TABLE IF NOT EXISTS {self.collection_name} (
+                    id VARCHAR(36) PRIMARY KEY,
+                    document TEXT NOT NULL,
+                    embedding VECTOR({self.dimension}) NOT NULL,
+                    metadata JSON,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    INDEX idx_created_at (created_at),
+                    FULLTEXT(document)
+                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
+            """)
+
+            cursor.execute(f"""
+                CREATE TABLE IF NOT EXISTS {self.collection_name}_sql_pairs (
+                    id VARCHAR(36) PRIMARY KEY,
+                    question TEXT NOT NULL,
+                    sql_query TEXT NOT NULL,
+                    embedding VECTOR({self.dimension}) NOT NULL,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    FULLTEXT(question, sql_query)
+                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
+            """)
+
+            conn.commit()
+            cursor.close()
+            conn.close()
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize MariaDB vector store: {e}") from e
+
+    def _format_vector_for_insertion(self, embedding_array):
+        if len(embedding_array) != self.dimension:
+            raise ValueError(f"Expected {self.dimension} dimensions, got {len(embedding_array)}")
+        return '[' + ','.join(f'{float(x)}' for x in embedding_array) + ']'
+
+    def add_ddl(self, ddl: str):
+        embedding = self.embedding_model.encode(ddl).tolist()
+        vector_json = self._format_vector_for_insertion(embedding)
+
+        conn = mariadb.connect(**self.connection_params)
+        cursor = conn.cursor()
+        doc_id = str(uuid.uuid4())
+        cursor.execute(f"""
+            INSERT INTO {self.collection_name}
+            (id, document, embedding, metadata)
+            VALUES (?, ?, VEC_FromText(?), ?)
+        """, (doc_id, ddl, vector_json, json.dumps({"type": "ddl"})))
+        conn.commit()
+        cursor.close()
+        conn.close()
+
+    def add_question_sql(self, question: str, sql: str):
+        embedding = self.embedding_model.encode(question).tolist()
+        vector_json = self._format_vector_for_insertion(embedding)
+
+        conn = mariadb.connect(**self.connection_params)
+        cursor = conn.cursor()
+        doc_id = str(uuid.uuid4())
+        cursor.execute(f"""
+            INSERT INTO {self.collection_name}_sql_pairs
+            (id, question, sql_query, embedding)
+            VALUES (?, ?, ?, VEC_FromText(?))
+        """, (doc_id, question, sql, vector_json))
+        conn.commit()
+        cursor.close()
+        conn.close()
+
+    def get_similar_question_sql(self, question: str, n_results: int = 5):
+        conn = mariadb.connect(**self.connection_params)
+        cursor = conn.cursor()
+        cursor.execute(f"""
+            SELECT question, sql_query,
+                   MATCH(question, sql_query) AGAINST (? IN NATURAL LANGUAGE MODE) as text_score
+            FROM {self.collection_name}_sql_pairs
+            WHERE MATCH(question, sql_query) AGAINST (? IN NATURAL LANGUAGE MODE)
+            ORDER BY text_score DESC LIMIT ?
+        """, (question, question, n_results))
+        results = cursor.fetchall()
+        cursor.close()
+        conn.close()
+
+        return [{'question': r[0], 'sql': r[1], 'similarity': r[2], 'text_score': r[2]} for r in results]
+
+    def retrieve_relevant_ddl(self, question: str, **kwargs) -> list:
+        conn = mariadb.connect(**self.connection_params)
+        cursor = conn.cursor()
+        cursor.execute(f"""
+            SELECT document FROM {self.collection_name}
+            WHERE JSON_EXTRACT(metadata, '$.type') = 'ddl'
+            ORDER BY created_at DESC LIMIT ?
+        """, (kwargs.get('n_results', 5),))
+        results = cursor.fetchall()
+        cursor.close()
+        conn.close()
+        return [row[0] for row in results]
+
+    def retrieve_relevant_documentation(self, question: str, **kwargs) -> list:
+        conn = mariadb.connect(**self.connection_params)
+        cursor = conn.cursor()
+        cursor.execute(f"""
+            SELECT document FROM {self.collection_name}
+            WHERE JSON_EXTRACT(metadata, '$.type') = 'documentation'
+            ORDER BY created_at DESC LIMIT ?
+        """, (kwargs.get('n_results', 5),))
+        results = cursor.fetchall()
+        cursor.close()
+        conn.close()
+        return [row[0] for row in results]
+
+    def retrieve_relevant_question_sql(self, question: str, **kwargs) -> list:
+        return self.get_similar_question_sql(question, kwargs.get('n_results', 3))
+
+    def index_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        try:
+            self.add_question_sql(question, sql)
+            return "Successfully added question-SQL pair"
+        except Exception as e:
+            return f"Failed: {e}"
+
+    def index_ddl(self, ddl: str, **kwargs) -> str:
+        try:
+            self.add_ddl(ddl)
+            return "Successfully added DDL"
+        except Exception as e:
+            return f"Failed: {e}"
+
+    def index_documentation(self, documentation: str, **kwargs) -> str:
+        try:
+            embedding = self.embedding_model.encode(documentation).tolist()
+            vector_json = self._format_vector_for_insertion(embedding)
+
+            conn = mariadb.connect(**self.connection_params)
+            cursor = conn.cursor()
+            doc_id = str(uuid.uuid4())
+            cursor.execute(f"""
+                INSERT INTO {self.collection_name}
+                (id, document, embedding, metadata)
+                VALUES (?, ?, VEC_FromText(?), ?)
+            """, (doc_id, documentation, vector_json, json.dumps({"type": "documentation"})))
+            conn.commit()
+            cursor.close()
+            conn.close()
+            return "Successfully added documentation"
+        except Exception as e:
+            return f"Failed: {e}"
+
+    def fetch_all_vectorstore_data(self, **kwargs) -> pd.DataFrame:
+        conn = mariadb.connect(**self.connection_params)
+        main_df = pd.read_sql(f"SELECT id, document, created_at FROM {self.collection_name}", conn)
+        sql_pairs_df = pd.read_sql(f"SELECT id, question, sql_query, created_at FROM {self.collection_name}_sql_pairs", conn)
+        conn.close()
+
+        data = []
+        for _, row in main_df.iterrows():
+            data.append({'id': row['id'], 'content': row['document'], 'type': 'document', 'created_at': row['created_at']})
+        for _, row in sql_pairs_df.iterrows():
+            data.append({'id': row['id'], 'content': f"Q: {row['question']} | SQL: {row['sql_query']}",
+                         'type': 'question_sql', 'created_at': row['created_at']})
+        return pd.DataFrame(data)
+
+    def delete_vectorstore_data(self, item_id: str, **kwargs) -> bool:
+        conn = mariadb.connect(**self.connection_params)
+        cursor = conn.cursor()
+        cursor.execute(f"DELETE FROM {self.collection_name} WHERE id = ?", (item_id,))
+        main_deleted = cursor.rowcount
+        cursor.execute(f"DELETE FROM {self.collection_name}_sql_pairs WHERE id = ?", (item_id,))
+        pairs_deleted = cursor.rowcount
+        conn.commit()
+        cursor.close()
+        conn.close()
+        return (main_deleted + pairs_deleted) > 0
+
+    def add_documents(self, documents: List[str]):
+        for doc in documents:
+            self.add_ddl(doc)