Skip to content

Commit 3d0ff02

Browse files
Merge pull request #20 from cyrannano/sqlserver-integration
SQL Server integration
2 parents b95ca73 + 701c046 commit 3d0ff02

File tree

4 files changed

+314
-0
lines changed

4 files changed

+314
-0
lines changed

mindsql/_utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,7 @@
3333
PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
3434
POSTGRESQL_SHOW_CREATE_TABLE_QUERY = """SELECT 'CREATE TABLE "' || table_name || '" (' || array_to_string(array_agg(column_name || ' ' || data_type), ', ') || ');' AS create_statement FROM information_schema.columns WHERE table_name = '{table}' GROUP BY table_name;"""
3535
ANTHROPIC_VALUE_ERROR = "Anthropic API key is required"
36+
# --- SQL Server metadata queries -------------------------------------------------------------------------------
# Returns one row per database name from sys.databases.
SQLSERVER_SHOW_DATABASE_QUERY = "SELECT name FROM sys.databases;"
# Returns every base table of {db} as a single 'schema.table' string. {db} is substituted between [...] delimiters,
# so callers must double any ']' in the name before formatting.
SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT CONCAT(TABLE_SCHEMA,'.',TABLE_NAME) FROM [{db}].INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE'"
# Rebuilds an approximate CREATE TABLE statement for {schema}.{table} from sys.tables / sys.columns / sys.types and
# returns it in a single column named SQLQuery. {table} and {schema} land inside single-quoted literals, so callers
# must double any single quote before formatting.
# sys.columns.max_length is measured in bytes; nchar/nvarchar columns therefore have it halved so the emitted length
# matches the declared character length (NVARCHAR(50) instead of NVARCHAR(100)).
SQLSERVER_SHOW_CREATE_TABLE_QUERY = "DECLARE @TableName NVARCHAR(MAX) = '{table}'; DECLARE @SchemaName NVARCHAR(MAX) = '{schema}'; DECLARE @SQL NVARCHAR(MAX); SELECT @SQL = 'CREATE TABLE ' + @SchemaName + '.' + t.name + ' (' + CHAR(13) + ( SELECT ' ' + c.name + ' ' + UPPER(tp.name) + CASE WHEN tp.name IN ('char', 'varchar', 'nchar', 'nvarchar') THEN '(' + CASE WHEN c.max_length = -1 THEN 'MAX' WHEN tp.name IN ('nchar', 'nvarchar') THEN CAST(c.max_length / 2 AS VARCHAR(10)) ELSE CAST(c.max_length AS VARCHAR(10)) END + ')' WHEN tp.name IN ('decimal', 'numeric') THEN '(' + CAST(c.precision AS VARCHAR(10)) + ',' + CAST(c.scale AS VARCHAR(10)) + ')' ELSE '' END + ',' + CHAR(13) FROM sys.columns c JOIN sys.types tp ON c.user_type_id = tp.user_type_id WHERE c.object_id = t.object_id ORDER BY c.column_id FOR XML PATH(''), TYPE ).value('.', 'NVARCHAR(MAX)') + CHAR(13) + ')' FROM sys.tables t JOIN sys.schemas s ON t.schema_id = s.schema_id WHERE t.name = @TableName AND s.name = @SchemaName; SELECT @SQL AS SQLQuery;"
3639
OLLAMA_CONFIG_REQUIRED = "{type} configuration is required."

mindsql/databases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .mysql import MySql
33
from .postgres import Postgres
44
from .sqlite import Sqlite
5+
from .sqlserver import SQLServer

mindsql/databases/sqlserver.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from typing import List, Optional
2+
from urllib.parse import urlparse
3+
4+
import pandas as pd
5+
import pyodbc
6+
7+
from . import IDatabase
8+
from .._utils import logger
9+
from .._utils.constants import ERROR_WHILE_RUNNING_QUERY, ERROR_CONNECTING_TO_DB_CONSTANT, INVALID_DB_CONNECTION_OBJECT, \
10+
CONNECTION_ESTABLISH_ERROR_CONSTANT, SQLSERVER_SHOW_DATABASE_QUERY, SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY, \
11+
SQLSERVER_SHOW_CREATE_TABLE_QUERY
12+
13+
# Module-level logger for the SQL Server adapter; the test module imports it as `log` to capture error output.
log = logger.init_loggers("SQL Server")
14+
15+
16+
class SQLServer(IDatabase):
    """
    SQL Server implementation of IDatabase built on top of pyodbc.

    Query failures raised by pyodbc are logged and turned into an "empty" return value (None, [] or ''),
    while an invalid connection object always raises ValueError from validate_connection.
    """

    @staticmethod
    def create_connection(url: str, **kwargs) -> Optional[pyodbc.Connection]:
        """
        Connects to a SQL Server database using the provided URL.

        Parameters:
            - url (str): The connection string to the SQL Server database in the format:
              'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password'
            - **kwargs: Additional keyword arguments forwarded to pyodbc.connect

        Returns:
            - connection: A connection to the SQL Server database, or None when pyodbc fails to connect
              (the error is logged).
        """
        try:
            return pyodbc.connect(url, **kwargs)
        except pyodbc.Error as e:
            log.error(ERROR_CONNECTING_TO_DB_CONSTANT.format("SQL Server", e))
            return None

    def execute_sql(self, connection, sql: str) -> Optional[pd.DataFrame]:
        """
        A function that runs an SQL query using the provided connection and returns the results as a pandas DataFrame.

        Parameters:
            connection: The connection object for the database.
            sql (str): The SQL query to be executed

        Returns:
            pd.DataFrame: A DataFrame containing the first result set produced by the query, an empty DataFrame
            when the statement produces no result set at all, or None when pyodbc reports an error.

        Raises:
            ValueError: If the provided connection is not a SQL Server connection.
        """
        self.validate_connection(connection)
        try:
            cursor = connection.cursor()
            try:
                cursor.execute(sql)
                # A batch can start with statements that yield only a row count (description is None);
                # advance until a real result set shows up instead of failing on the missing description.
                while cursor.description is None:
                    if not cursor.nextset():
                        return pd.DataFrame()
                columns = [column[0] for column in cursor.description]
                rows = [list(row) for row in cursor.fetchall()]
            finally:
                # Release the cursor even when execute/fetch raises.
                cursor.close()
            return pd.DataFrame(rows, columns=columns)
        except pyodbc.Error as e:
            log.error(ERROR_WHILE_RUNNING_QUERY.format(e))
            return None

    def get_databases(self, connection) -> List[str]:
        """
        Get a list of databases from the given connection.

        Parameters:
            connection: The connection object for the database.

        Returns:
            List[str]: A list of database names, or an empty list when the query fails (the error is logged).

        Raises:
            ValueError: If the provided connection is not a SQL Server connection.
        """
        self.validate_connection(connection)
        try:
            cursor = connection.cursor()
            try:
                cursor.execute(SQLSERVER_SHOW_DATABASE_QUERY)
                return [row[0] for row in cursor.fetchall()]
            finally:
                # Release the cursor even when execute/fetch raises.
                cursor.close()
        except pyodbc.Error as e:
            log.error(ERROR_WHILE_RUNNING_QUERY.format(e))
            return []

    def get_table_names(self, connection, database: str) -> Optional[pd.DataFrame]:
        """
        Retrieves the tables along with their schema (schema.table_name) from the information schema for the specified
        database.

        Parameters:
            connection: The database connection object.
            database (str): The name of the database.

        Returns:
            DataFrame: A pandas DataFrame containing the table names from the information schema, or None when
            the query fails.

        Raises:
            ValueError: If the provided connection is not a SQL Server connection.
        """
        self.validate_connection(connection)
        # The name is placed between [...] delimiters; doubling ']' keeps it a single identifier.
        escaped_database = database.replace(']', ']]')
        query = SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY.format(db=escaped_database)
        return self.execute_sql(connection, query)

    def get_all_ddls(self, connection, database: str) -> pd.DataFrame:
        """
        A method to get the DDLs for all the tables in the database.

        Parameters:
            connection (any): The connection object.
            database (str): The name of the database.

        Returns:
            DataFrame: A pandas DataFrame with the columns 'Table' and 'DDL', one row per table. It is empty when
            the table listing fails or the database has no base tables.
        """
        df_tables = self.get_table_names(connection, database)
        records = []
        if df_tables is not None and not df_tables.empty:
            # Collect all rows first and build the frame once, avoiding the private per-row DataFrame._append.
            records = [
                {'Table': qualified_name, 'DDL': self.get_ddl(connection, qualified_name)}
                for qualified_name in df_tables.iloc[:, 0]
            ]
        return pd.DataFrame(records, columns=['Table', 'DDL'])

    def validate_connection(self, connection) -> None:
        """
        A function that validates if the provided connection is a SQL Server connection.

        Parameters:
            connection: The connection object for accessing the database.

        Raises:
            ValueError: If the connection is None or is not a pyodbc.Connection instance.

        Returns:
            None
        """
        if connection is None:
            raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
        if not isinstance(connection, pyodbc.Connection):
            raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("SQL Server"))

    def get_ddl(self, connection, table_name: str, **kwargs) -> str:
        """
        Builds the CREATE TABLE statement for a single table.

        Parameters:
            connection: The connection object for the database.
            table_name (str): The table as 'schema.table'. A name without a schema is looked up in 'dbo';
                when several dots are present, the first one separates schema from table.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            str: The generated DDL, or an empty string when the query fails or the table is not found.

        Raises:
            ValueError: If the provided connection is not a SQL Server connection.
        """
        schema_name, separator, bare_table = table_name.partition('.')
        if not separator:
            # No schema prefix supplied: fall back to dbo, SQL Server's default schema.
            schema_name, bare_table = 'dbo', table_name
        # Both values are spliced into single-quoted T-SQL literals, so embedded quotes must be doubled.
        query = SQLSERVER_SHOW_CREATE_TABLE_QUERY.format(
            table=bare_table.replace("'", "''"),
            schema=schema_name.replace("'", "''"),
        )
        df_ddl = self.execute_sql(connection, query)
        if df_ddl is None or df_ddl.empty:
            return ''
        ddl = df_ddl['SQLQuery'].iloc[0]
        # The query yields NULL when no table matches the schema/name pair.
        return '' if pd.isna(ddl) else ddl

    def get_dialect(self) -> str:
        """Returns the SQL dialect identifier for this database: 'tsql'."""
        return 'tsql'

tests/sqlserver_test.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
import pyodbc
4+
import pandas as pd
5+
from mindsql.databases.sqlserver import SQLServer, ERROR_WHILE_RUNNING_QUERY, ERROR_CONNECTING_TO_DB_CONSTANT, \
6+
INVALID_DB_CONNECTION_OBJECT, CONNECTION_ESTABLISH_ERROR_CONSTANT
7+
from mindsql.databases.sqlserver import log as logger
8+
9+
10+
class TestSQLServer(unittest.TestCase):
    """Unit tests for the SQLServer adapter; pyodbc is mocked, so no database is contacted."""

    @staticmethod
    def _open_mocked_connection(mock_connect, connection_string='fake_connection_string'):
        """Wire a mocked pyodbc connection/cursor pair behind pyodbc.connect and open it via SQLServer."""
        cursor = MagicMock()
        raw_connection = MagicMock(spec=pyodbc.Connection)
        raw_connection.cursor.return_value = cursor
        mock_connect.return_value = raw_connection
        return SQLServer.create_connection(connection_string), cursor

    def _assert_logged(self, captured, expected_text):
        """Assert that at least one captured log line contains expected_text."""
        self.assertTrue(any(expected_text in line for line in captured.output))

    @patch('mindsql.databases.sqlserver.pyodbc.connect')
    def test_create_connection_success(self, mock_connect):
        mock_connect.return_value = MagicMock(spec=pyodbc.Connection)
        opened = SQLServer.create_connection(
            'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password')
        self.assertIsInstance(opened, pyodbc.Connection)

    @patch('mindsql.databases.sqlserver.pyodbc.connect')
    def test_create_connection_failure(self, mock_connect):
        mock_connect.side_effect = pyodbc.Error('Connection failed')
        with self.assertLogs(logger, level='ERROR') as captured:
            opened = SQLServer.create_connection(
                'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password')
            self.assertIsNone(opened)
            self._assert_logged(captured, ERROR_CONNECTING_TO_DB_CONSTANT.format("SQL Server", 'Connection failed'))

    @patch('mindsql.databases.sqlserver.pyodbc.connect')
    def test_execute_sql_success(self, mock_connect):
        connection, cursor = self._open_mocked_connection(mock_connect)

        # Cursor yields two columns and two rows.
        cursor.execute.return_value = None
        cursor.description = [('column1',), ('column2',)]
        cursor.fetchall.return_value = [(1, 'a'), (2, 'b')]

        actual = SQLServer().execute_sql(connection, "SELECT * FROM table")

        expected = pd.DataFrame(data=[(1, 'a'), (2, 'b')], columns=['column1', 'column2'])
        pd.testing.assert_frame_equal(actual, expected)

    @patch('mindsql.databases.sqlserver.pyodbc.connect')
    def test_execute_sql_failure(self, mock_connect):
        connection, cursor = self._open_mocked_connection(mock_connect)
        cursor.execute.side_effect = pyodbc.Error('Query failed')
        adapter = SQLServer()

        with self.assertLogs(logger, level='ERROR') as captured:
            outcome = adapter.execute_sql(connection, "SELECT * FROM table")
            self.assertIsNone(outcome)
            self._assert_logged(captured, ERROR_WHILE_RUNNING_QUERY.format('Query failed'))

    @patch('mindsql.databases.sqlserver.pyodbc.connect')
    def test_get_databases_success(self, mock_connect):
        connection, cursor = self._open_mocked_connection(mock_connect)

        # Cursor yields one single-column row per database.
        cursor.execute.return_value = None
        cursor.fetchall.return_value = [('database1',), ('database2',)]

        names = SQLServer().get_databases(connection)
        self.assertEqual(names, ['database1', 'database2'])

    @patch('mindsql.databases.sqlserver.pyodbc.connect')
    def test_get_databases_failure(self, mock_connect):
        connection, cursor = self._open_mocked_connection(mock_connect)
        cursor.execute.side_effect = pyodbc.Error('Query failed')
        adapter = SQLServer()

        with self.assertLogs(logger, level='ERROR') as captured:
            names = adapter.get_databases(connection)
            self.assertEqual(names, [])
            self._assert_logged(captured, ERROR_WHILE_RUNNING_QUERY.format('Query failed'))

    @patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
    def test_get_table_names_success(self, mock_execute_sql):
        listing = [('schema1.table1',), ('schema2.table2',)]
        mock_execute_sql.return_value = pd.DataFrame(data=listing, columns=['table_name'])

        actual = SQLServer().get_table_names(MagicMock(spec=pyodbc.Connection), 'database_name')

        expected = pd.DataFrame(data=[('schema1.table1',), ('schema2.table2',)], columns=['table_name'])
        pd.testing.assert_frame_equal(actual, expected)

    @patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
    def test_get_all_ddls_success(self, mock_execute_sql):
        # First call answers the table listing, second call answers the DDL lookup.
        table_listing = pd.DataFrame(data=[('schema1.table1',)], columns=['table_name'])
        ddl_lookup = pd.DataFrame(data=['CREATE TABLE schema1.table1 (...);'], columns=['SQLQuery'])
        mock_execute_sql.side_effect = [table_listing, ddl_lookup]

        actual = SQLServer().get_all_ddls(MagicMock(spec=pyodbc.Connection), 'database_name')

        expected = pd.DataFrame(data=[{'Table': 'schema1.table1', 'DDL': 'CREATE TABLE schema1.table1 (...);'}])
        pd.testing.assert_frame_equal(actual, expected)

    def test_validate_connection_success(self):
        # Passes when no exception is raised.
        SQLServer().validate_connection(MagicMock(spec=pyodbc.Connection))

    def test_validate_connection_failure(self):
        adapter = SQLServer()
        rejected = (
            (None, CONNECTION_ESTABLISH_ERROR_CONSTANT),
            ("InvalidConnectionObject", INVALID_DB_CONNECTION_OBJECT.format("SQL Server")),
        )
        for bad_connection, expected_message in rejected:
            with self.assertRaises(ValueError) as raised:
                adapter.validate_connection(bad_connection)
            self.assertEqual(str(raised.exception), expected_message)

    @patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
    def test_get_ddl_success(self, mock_execute_sql):
        mock_execute_sql.return_value = pd.DataFrame(data=['CREATE TABLE schema1.table1 (...);'],
                                                     columns=['SQLQuery'])

        ddl = SQLServer().get_ddl(MagicMock(spec=pyodbc.Connection), 'schema1.table1')
        self.assertEqual(ddl, 'CREATE TABLE schema1.table1 (...);')

    def test_get_dialect(self):
        self.assertEqual(SQLServer().get_dialect(), 'tsql')
160+
161+
162+
# Run the test suite when this file is executed directly rather than through a test runner.
if __name__ == '__main__':
163+
unittest.main()

0 commit comments

Comments
 (0)