Skip to content

Commit 1d9197e

Browse files
author
Szymon Cyranik
committed
test(llm): add unit tests for Ollama client
1 parent 0fb2eda commit 1d9197e

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

tests/ollama_test.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
from ollama import Client, Options
4+
5+
from mindsql.llms import ILlm
6+
from mindsql.llms import Ollama
7+
from mindsql._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED
8+
9+
10+
class TestOllama(unittest.TestCase):
11+
12+
def setUp(self):
13+
# Common setup for each test case
14+
self.model_config = {'model': 'sqlcoder'}
15+
self.client_config = {'host': 'http://localhost:11434/'}
16+
self.client_mock = MagicMock(spec=Client)
17+
18+
def test_initialization_with_client(self):
19+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
20+
self.assertEqual(ollama.client, self.client_mock)
21+
self.assertIsNone(ollama.client_config)
22+
self.assertEqual(ollama.model_config, self.model_config)
23+
24+
def test_initialization_with_client_config(self):
25+
ollama = Ollama(model_config=self.model_config, client_config=self.client_config)
26+
self.assertIsNotNone(ollama.client)
27+
self.assertEqual(ollama.client_config, self.client_config)
28+
self.assertEqual(ollama.model_config, self.model_config)
29+
30+
def test_initialization_missing_client_and_client_config(self):
31+
with self.assertRaises(ValueError) as context:
32+
Ollama(model_config=self.model_config)
33+
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Client"))
34+
35+
def test_initialization_missing_model_config(self):
36+
with self.assertRaises(ValueError) as context:
37+
Ollama(model_config=None, client_config=self.client_config)
38+
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model"))
39+
40+
def test_initialization_missing_model_name(self):
41+
with self.assertRaises(ValueError) as context:
42+
Ollama(model_config={}, client_config=self.client_config)
43+
self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model name"))
44+
45+
def test_system_message(self):
46+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
47+
message = ollama.system_message("Test system message")
48+
self.assertEqual(message, {"role": "system", "content": "Test system message"})
49+
50+
def test_user_message(self):
51+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
52+
message = ollama.user_message("Test user message")
53+
self.assertEqual(message, {"role": "user", "content": "Test user message"})
54+
55+
def test_assistant_message(self):
56+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
57+
message = ollama.assistant_message("Test assistant message")
58+
self.assertEqual(message, {"role": "assistant", "content": "Test assistant message"})
59+
60+
@patch.object(Client, 'chat', return_value={'message': {'content': 'Test response'}})
61+
def test_invoke_success(self, mock_chat):
62+
ollama = Ollama(model_config=self.model_config, client=Client())
63+
response = ollama.invoke("Test prompt")
64+
65+
# Check if the response is as expected
66+
self.assertEqual(response, 'Test response')
67+
68+
# Verify that the chat method was called with the correct arguments
69+
mock_chat.assert_called_once_with(
70+
model=self.model_config['model'],
71+
messages=[{"role": "user", "content": "Test prompt"}],
72+
options=Options(temperature=0.1)
73+
)
74+
75+
def test_invoke_empty_prompt(self):
76+
ollama = Ollama(model_config=self.model_config, client=self.client_mock)
77+
with self.assertRaises(ValueError) as context:
78+
ollama.invoke("")
79+
self.assertEqual(str(context.exception), PROMPT_EMPTY_EXCEPTION)
80+
81+
82+
if __name__ == '__main__':
83+
unittest.main()

0 commit comments

Comments
 (0)