Skip to content

Commit 701c046

Browse files
authored
Merge branch 'master' into sqlserver-integration
2 parents 845f26d + b95ca73 commit 701c046

File tree

3 files changed

+189
-0
lines changed

3 files changed

+189
-0
lines changed

mindsql/_utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@
3636
SQLSERVER_SHOW_DATABASE_QUERY= "SELECT name FROM sys.databases;"
3737
SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT CONCAT(TABLE_SCHEMA,'.',TABLE_NAME) FROM [{db}].INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE'"
3838
SQLSERVER_SHOW_CREATE_TABLE_QUERY = "DECLARE @TableName NVARCHAR(MAX) = '{table}'; DECLARE @SchemaName NVARCHAR(MAX) = '{schema}'; DECLARE @SQL NVARCHAR(MAX); SELECT @SQL = 'CREATE TABLE ' + @SchemaName + '.' + t.name + ' (' + CHAR(13) + ( SELECT ' ' + c.name + ' ' + UPPER(tp.name) + CASE WHEN tp.name IN ('char', 'varchar', 'nchar', 'nvarchar') THEN '(' + CASE WHEN c.max_length = -1 THEN 'MAX' ELSE CAST(c.max_length AS VARCHAR(10)) END + ')' WHEN tp.name IN ('decimal', 'numeric') THEN '(' + CAST(c.precision AS VARCHAR(10)) + ',' + CAST(c.scale AS VARCHAR(10)) + ')' ELSE '' END + ',' + CHAR(13) FROM sys.columns c JOIN sys.types tp ON c.user_type_id = tp.user_type_id WHERE c.object_id = t.object_id ORDER BY c.column_id FOR XML PATH(''), TYPE ).value('.', 'NVARCHAR(MAX)') + CHAR(13) + ')' FROM sys.tables t JOIN sys.schemas s ON t.schema_id = s.schema_id WHERE t.name = @TableName AND s.name = @SchemaName; SELECT @SQL AS SQLQuery;"
39+
OLLAMA_CONFIG_REQUIRED = "{type} configuration is required."

mindsql/llms/ollama.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from ollama import Client, Options
2+
3+
from .illm import ILlm
4+
from .._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED
5+
from .._utils import logger
6+
7+
log = logger.init_loggers("Ollama Client")
8+
9+
10+
class Ollama(ILlm):
11+
def __init__(self, model_config: dict, client_config=None, client: Client = None):
12+
"""
13+
Initialize the class with an optional config parameter.
14+
15+
Parameters:
16+
model_config (dict): The model configuration parameter.
17+
config (dict): The configuration parameter.
18+
client (Client): The client parameter.
19+
20+
Returns:
21+
None
22+
"""
23+
self.client = client
24+
self.client_config = client_config
25+
self.model_config = model_config
26+
27+
if self.client is not None:
28+
if self.client_config is not None:
29+
log.warning("Client object provided. Ignoring client_config parameter.")
30+
return
31+
32+
if client_config is None:
33+
raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Client"))
34+
35+
if model_config is None:
36+
raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model"))
37+
38+
if 'model' not in model_config:
39+
raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model name"))
40+
41+
self.client = Client(**client_config)
42+
43+
def system_message(self, message: str) -> any:
44+
"""
45+
Create a system message.
46+
47+
Parameters:
48+
message (str): The message parameter.
49+
50+
Returns:
51+
any
52+
"""
53+
return {"role": "system", "content": message}
54+
55+
def user_message(self, message: str) -> any:
56+
"""
57+
Create a user message.
58+
59+
Parameters:
60+
message (str): The message parameter.
61+
62+
Returns:
63+
any
64+
"""
65+
return {"role": "user", "content": message}
66+
67+
def assistant_message(self, message: str) -> any:
68+
"""
69+
Create an assistant message.
70+
71+
Parameters:
72+
message (str): The message parameter.
73+
74+
Returns:
75+
any
76+
"""
77+
return {"role": "assistant", "content": message}
78+
79+
def invoke(self, prompt, **kwargs) -> str:
80+
"""
81+
Submit a prompt to the model for generating a response.
82+
83+
Parameters:
84+
prompt (str): The prompt parameter.
85+
**kwargs: Additional keyword arguments (optional).
86+
- temperature (float): The temperature parameter for controlling randomness in generation.
87+
88+
Returns:
89+
str
90+
"""
91+
if not prompt:
92+
raise ValueError(PROMPT_EMPTY_EXCEPTION)
93+
94+
model = self.model_config.get('model')
95+
temperature = kwargs.get('temperature', 0.1)
96+
97+
response = self.client.chat(
98+
model=model,
99+
messages=[self.user_message(prompt)],
100+
options=Options(
101+
temperature=temperature
102+
)
103+
)
104+
105+
return response['message']['content']

tests/ollama_test.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
from ollama import Client, Options
4+
5+
from mindsql.llms import ILlm
6+
from mindsql.llms import Ollama
7+
from mindsql._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED
8+
9+
10+
class TestOllama(unittest.TestCase):
11+
12+
def setUp(self):
13+
# Common setup for each test case
14+
self.model_config = {'model': 'sqlcoder'}
15+
self.client_config = {'host': 'http://localhost:11434/'}
16+
self.client_mock = MagicMock(spec=Client)
17+
18+
def test_initialization_with_client(self):
19+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
20+
self.assertEqual(ollama.client, self.client_mock)
21+
self.assertIsNone(ollama.client_config)
22+
self.assertEqual(ollama.model_config, self.model_config)
23+
24+
def test_initialization_with_client_config(self):
25+
ollama = Ollama(model_config=self.model_config, client_config=self.client_config)
26+
self.assertIsNotNone(ollama.client)
27+
self.assertEqual(ollama.client_config, self.client_config)
28+
self.assertEqual(ollama.model_config, self.model_config)
29+
30+
def test_initialization_missing_client_and_client_config(self):
31+
with self.assertRaises(ValueError) as context:
32+
Ollama(model_config=self.model_config)
33+
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Client"))
34+
35+
def test_initialization_missing_model_config(self):
36+
with self.assertRaises(ValueError) as context:
37+
Ollama(model_config=None, client_config=self.client_config)
38+
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model"))
39+
40+
def test_initialization_missing_model_name(self):
41+
with self.assertRaises(ValueError) as context:
42+
Ollama(model_config={}, client_config=self.client_config)
43+
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model name"))
44+
45+
def test_system_message(self):
46+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
47+
message = ollama.system_message("Test system message")
48+
self.assertEqual(message, {"role": "system", "content": "Test system message"})
49+
50+
def test_user_message(self):
51+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
52+
message = ollama.user_message("Test user message")
53+
self.assertEqual(message, {"role": "user", "content": "Test user message"})
54+
55+
def test_assistant_message(self):
56+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
57+
message = ollama.assistant_message("Test assistant message")
58+
self.assertEqual(message, {"role": "assistant", "content": "Test assistant message"})
59+
60+
@patch.object(Client, 'chat', return_value={'message': {'content': 'Test response'}})
61+
def test_invoke_success(self, mock_chat):
62+
ollama = Ollama(model_config=self.model_config, client=Client())
63+
response = ollama.invoke("Test prompt")
64+
65+
# Check if the response is as expected
66+
self.assertEqual(response, 'Test response')
67+
68+
# Verify that the chat method was called with the correct arguments
69+
mock_chat.assert_called_once_with(
70+
model=self.model_config['model'],
71+
messages=[{"role": "user", "content": "Test prompt"}],
72+
options=Options(temperature=0.1)
73+
)
74+
75+
def test_invoke_empty_prompt(self):
76+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
77+
with self.assertRaises(ValueError) as context:
78+
ollama.invoke("")
79+
self.assertEqual(str(context.exception), PROMPT_EMPTY_EXCEPTION)
80+
81+
82+
if __name__ == '__main__':
83+
unittest.main()

0 commit comments

Comments
 (0)