Skip to content

Commit d6b57f6

Browse files
committed
Enhance logging and error handling in openai module
This commit introduces better logging and error handling in the openai module. The print statements have been replaced with logger calls which provide more insight and control over the output. Additionally, a new unit test file (test_openai.py) has been created to test the 'openai' utilities. Lastly, 'docstring-parser' dependency was added to requirements.txt.
1 parent 28658f7 commit d6b57f6

File tree

3 files changed

+81
-10
lines changed

3 files changed

+81
-10
lines changed

mindsdb_sdk/utils/openai.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,31 @@
11
import json
2+
from logging import getLogger
3+
from typing import List
4+
5+
import openai
6+
from openai.types.chat import ChatCompletionToolChoiceOptionParam
27

38
from mindsdb_sdk.databases import Database
49
from tenacity import retry, wait_random_exponential, stop_after_attempt
510

611

7-
@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
8-
def chat_completion_request(client, model, messages, tools=None, tool_choice=None):
12+
DEFAULT_RETRY_MULTIPLIER = 1
13+
DEFAULT_MAX_WAIT = 40
14+
DEFAULT_STOP_AFTER_ATTEMPT = 3
15+
16+
logger = getLogger(__name__)
17+
18+
19+
@retry(wait=wait_random_exponential(multiplier=DEFAULT_RETRY_MULTIPLIER, max=DEFAULT_MAX_WAIT), stop=stop_after_attempt(
20+
DEFAULT_RETRY_MULTIPLIER
21+
))
22+
def chat_completion_request(
23+
client: openai.OpenAI,
24+
model: str,
25+
messages: List[dict],
26+
tools: List = None,
27+
tool_choice: ChatCompletionToolChoiceOptionParam = None
28+
):
929
try:
1030
response = client.chat.completions.create(
1131
model=model,
@@ -15,8 +35,8 @@ def chat_completion_request(client, model, messages, tools=None, tool_choice=Non
1535
)
1636
return response
1737
except Exception as e:
18-
print("Unable to generate ChatCompletion response")
19-
print(f"Exception: {e}")
38+
logger.warning("Unable to generate ChatCompletion response")
39+
logger.warning(f"Exception: {e}")
2040
return e
2141

2242

@@ -153,7 +173,7 @@ def litellm_text2sql_callback_tool(
153173
"description":"Data source name",
154174
},
155175
"connection_args":{
156-
"type":"object",
176+
"type":"string",
157177
"description":"Connection arguments for the data source",
158178
},
159179
"description":{
@@ -246,12 +266,12 @@ def pretty_print_conversation(messages):
246266

247267
for message in messages:
248268
if message["role"] == "system":
249-
print(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
269+
logger.info(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
250270
elif message["role"] == "user":
251-
print(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
271+
logger.info(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
252272
elif message["role"] == "assistant" and message.get("function_call"):
253-
print(colored(f"assistant: {message['function_call']}\n", role_to_color[message["role"]]))
273+
logger.info(colored(f"assistant: {message['function_call']}\n", role_to_color[message["role"]]))
254274
elif message["role"] == "assistant" and not message.get("function_call"):
255-
print(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
275+
logger.info(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
256276
elif message["role"] == "function":
257-
print(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))
277+
logger.info(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
requests
22
pandas >= 1.3.5
33
mindsdb-sql >= 0.13.0, < 0.14.0
4+
docstring-parser >= 0.7.3

tests/test_openai.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import json
2+
from unittest.mock import patch, MagicMock
3+
from mindsdb_sdk.utils import openai
4+
5+
6+
def test_chat_completion_request_success():
7+
mock_client = MagicMock()
8+
mock_client.chat.completions.create.return_value = "Test Response"
9+
response = openai.chat_completion_request(mock_client, "text-davinci-002", [{"role": "system", "content": "You are a helpful assistant."}])
10+
assert response == "Test Response"
11+
12+
13+
def test_make_openai_tool():
14+
def test_func(a: int, b: str) -> str:
15+
"""This is a test function"""
16+
return b * a
17+
tool = openai.make_openai_tool(test_func)
18+
assert tool["function"]["name"] == "test_func"
19+
assert tool["function"]["description"] == "This is a test function"
20+
assert tool["function"]["parameters"]["properties"]["a"]["type"] == "int"
21+
assert tool["function"]["parameters"]["properties"]["b"]["type"] == "str"
22+
23+
24+
def test_extract_sql_query():
25+
result = "SQLQuery: SELECT * FROM test_table\nSQLResult: [{'column1': 'value1', 'column2': 'value2'}]"
26+
query = openai.extract_sql_query(result)
27+
assert query == "SELECT * FROM test_table"
28+
29+
30+
def test_extract_sql_query_no_query():
31+
result = "SQLResult: [{'column1': 'value1', 'column2': 'value2'}]"
32+
query = openai.extract_sql_query(result)
33+
assert query is None
34+
35+
36+
@patch("mindsdb_sdk.utils.openai.query_database")
37+
def test_execute_function_call_query_database(mock_query_database):
38+
mock_query_database.return_value = "Test Result"
39+
mock_message = MagicMock()
40+
mock_message.tool_calls[0].function.name = "query_database"
41+
mock_message.tool_calls[0].function.arguments = json.dumps({"query": "SELECT * FROM test_table"})
42+
result = openai.execute_function_call(mock_message, MagicMock())
43+
assert result == "Test Result"
44+
45+
46+
def test_execute_function_call_no_function():
47+
mock_message = MagicMock()
48+
mock_message.tool_calls[0].function.name = "non_existent_function"
49+
result = openai.execute_function_call(mock_message, MagicMock())
50+
assert result == "Error: function non_existent_function does not exist"

0 commit comments

Comments
 (0)