Skip to content

Commit ff1c0ee

Browse files
committed
Merge branch 'staging' of https://github.com/mindsdb/mindsdb_python_sdk into staging
2 parents 6718cb0 + 6b0f0a3 commit ff1c0ee

File tree

5 files changed

+93
-23
lines changed

5 files changed

+93
-23
lines changed

examples/using_mindsdb_inference_with_text2sql_using_tools.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33

44
from mindsdb_sdk.utils.openai import (
5-
make_mindsdb_tool,
6-
execute_function_call,
7-
chat_completion_request,
8-
pretty_print_conversation)
5+
make_query_tool,
6+
execute_function_call,
7+
chat_completion_request,
8+
pretty_print_conversation)
99

1010
import mindsdb_sdk
1111
import os
@@ -33,10 +33,10 @@
3333
api_key=OPENAI_API_KEY
3434
)
3535

36-
database = con.databases.get("example_db")
36+
database = con.databases.get("mindsdb_demo_db")
3737
schema = get_table_schemas(database, included_tables=["airline_passenger_satisfaction"])
3838

39-
tools = [make_mindsdb_tool(schema)]
39+
tools = [make_query_tool(schema)]
4040

4141
SYSTEM_PROMPT = """You are a SQL expert. Given an input question, Answer user questions by generating SQL queries
4242
against the database schema provided in tools

mindsdb_sdk/utils/openai.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,31 @@
11
import json
2+
from logging import getLogger
3+
from typing import List
4+
5+
import openai
6+
from openai.types.chat import ChatCompletionToolChoiceOptionParam
27

38
from mindsdb_sdk.databases import Database
49
from tenacity import retry, wait_random_exponential, stop_after_attempt
510

611

7-
@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
8-
def chat_completion_request(client, model, messages, tools=None, tool_choice=None):
12+
DEFAULT_RETRY_MULTIPLIER = 1
13+
DEFAULT_MAX_WAIT = 40
14+
DEFAULT_STOP_AFTER_ATTEMPT = 3
15+
16+
logger = getLogger(__name__)
17+
18+
19+
@retry(wait=wait_random_exponential(multiplier=DEFAULT_RETRY_MULTIPLIER, max=DEFAULT_MAX_WAIT), stop=stop_after_attempt(
20+
DEFAULT_RETRY_MULTIPLIER
21+
))
22+
def chat_completion_request(
23+
client: openai.OpenAI,
24+
model: str,
25+
messages: List[dict],
26+
tools: List = None,
27+
tool_choice: ChatCompletionToolChoiceOptionParam = None
28+
):
929
try:
1030
response = client.chat.completions.create(
1131
model=model,
@@ -15,8 +35,8 @@ def chat_completion_request(client, model, messages, tools=None, tool_choice=Non
1535
)
1636
return response
1737
except Exception as e:
18-
print("Unable to generate ChatCompletion response")
19-
print(f"Exception: {e}")
38+
logger.warning("Unable to generate ChatCompletion response")
39+
logger.warning(f"Exception: {e}")
2040
return e
2141

2242

@@ -29,7 +49,6 @@ def make_openai_tool(function: callable, description: str = None) -> dict:
2949
3050
:return: dictionary containing function metadata
3151
"""
32-
# You will need to pip install docstring-parser to use this function
3352

3453
import inspect
3554
import docstring_parser
@@ -77,7 +96,7 @@ def make_openai_tool(function: callable, description: str = None) -> dict:
7796
return function_dict
7897

7998

80-
def make_mindsdb_tool(schema: dict) -> dict:
99+
def make_query_tool(schema: dict) -> dict:
81100
"""
82101
Make an OpenAI tool for querying a database connection in MindsDB
83102
@@ -109,14 +128,14 @@ def make_mindsdb_tool(schema: dict) -> dict:
109128
}
110129

111130

112-
def litellm_text2sql_callback_tool(
131+
def make_data_tool(
113132
model: str,
114133
data_source: str,
115134
description: str,
116135
connection_args: dict
117136
):
118137
"""
119-
tool passing connection details for datasource to litellm callback
138+
tool passing mindsdb database connection details for datasource to litellm callback
120139
121140
:param model: model name for text to sql completion
122141
:param data_source: data source name
@@ -153,7 +172,7 @@ def litellm_text2sql_callback_tool(
153172
"description":"Data source name",
154173
},
155174
"connection_args":{
156-
"type":"object",
175+
"type":"string",
157176
"description":"Connection arguments for the data source",
158177
},
159178
"description":{
@@ -246,12 +265,12 @@ def pretty_print_conversation(messages):
246265

247266
for message in messages:
248267
if message["role"] == "system":
249-
print(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
268+
logger.info(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
250269
elif message["role"] == "user":
251-
print(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
270+
logger.info(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
252271
elif message["role"] == "assistant" and message.get("function_call"):
253-
print(colored(f"assistant: {message['function_call']}\n", role_to_color[message["role"]]))
272+
logger.info(colored(f"assistant: {message['function_call']}\n", role_to_color[message["role"]]))
254273
elif message["role"] == "assistant" and not message.get("function_call"):
255-
print(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
274+
logger.info(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
256275
elif message["role"] == "function":
257-
print(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))
276+
logger.info(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))

mindsdb_sdk/utils/table_schema.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ def get_table_schemas(database: Databases, included_tables: List[str] = None, n_
4848

4949
table_schemas = {}
5050
for table in tables:
51-
# Get the first 10 rows of the table
52-
table_df = database.get_table(table).fetch().head(n_rows)
53-
# Convert schema to list of dictionaries
51+
table_df = database.get_table(table).limit(n_rows).fetch()
5452
table_schemas[table] = get_dataframe_schema(table_df)
5553

5654
return table_schemas

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
requests
22
pandas >= 1.3.5
33
mindsdb-sql >= 0.13.0, < 0.14.0
4+
docstring-parser >= 0.7.3
5+
tenacity >= 8.0.1
6+
openai >= 1.15.0

tests/test_openai.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import json
2+
from unittest.mock import patch, MagicMock
3+
from mindsdb_sdk.utils import openai
4+
5+
6+
def test_chat_completion_request_success():
7+
mock_client = MagicMock()
8+
mock_client.chat.completions.create.return_value = "Test Response"
9+
response = openai.chat_completion_request(mock_client, "text-davinci-002", [{"role": "system", "content": "You are a helpful assistant."}])
10+
assert response == "Test Response"
11+
12+
13+
def test_make_openai_tool():
14+
def test_func(a: int, b: str) -> str:
15+
"""This is a test function"""
16+
return b * a
17+
tool = openai.make_openai_tool(test_func)
18+
assert tool["function"]["name"] == "test_func"
19+
assert tool["function"]["description"] == "This is a test function"
20+
assert tool["function"]["parameters"]["properties"]["a"]["type"] == "int"
21+
assert tool["function"]["parameters"]["properties"]["b"]["type"] == "str"
22+
23+
24+
def test_extract_sql_query():
25+
result = "SQLQuery: SELECT * FROM test_table\nSQLResult: [{'column1': 'value1', 'column2': 'value2'}]"
26+
query = openai.extract_sql_query(result)
27+
assert query == "SELECT * FROM test_table"
28+
29+
30+
def test_extract_sql_query_no_query():
31+
result = "SQLResult: [{'column1': 'value1', 'column2': 'value2'}]"
32+
query = openai.extract_sql_query(result)
33+
assert query is None
34+
35+
36+
@patch("mindsdb_sdk.utils.openai.query_database")
37+
def test_execute_function_call_query_database(mock_query_database):
38+
mock_query_database.return_value = "Test Result"
39+
mock_message = MagicMock()
40+
mock_message.tool_calls[0].function.name = "query_database"
41+
mock_message.tool_calls[0].function.arguments = json.dumps({"query": "SELECT * FROM test_table"})
42+
result = openai.execute_function_call(mock_message, MagicMock())
43+
assert result == "Test Result"
44+
45+
46+
def test_execute_function_call_no_function():
47+
mock_message = MagicMock()
48+
mock_message.tool_calls[0].function.name = "non_existent_function"
49+
result = openai.execute_function_call(mock_message, MagicMock())
50+
assert result == "Error: function non_existent_function does not exist"

0 commit comments

Comments
 (0)