
Commit c456587

Add example using MindsDB for text2SQL tasks and update util functions
This commit introduces an example that uses the MindsDB and OpenAI APIs to perform text2SQL tasks. It also adds several utility-function improvements: retries for OpenAI chat completion requests, function-call execution, SQL query extraction, and pretty printing of the conversation history. Lastly, the file `using_mindsdb_llm_inference_with_tools.py` has been renamed to `using_mindsdb_inference_with_text2sql_using_tools.py` for better clarity.
1 parent bbfee91 commit c456587

File tree: 4 files changed, +282 -50 lines changed

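For orientation, the pieces this commit adds compose into a simple text2SQL loop: ask the model for SQL, extract it, run it against a MindsDB database connection, and feed the result back. Below is a minimal sketch using the new helpers from `mindsdb_sdk.utils.openai`, condensed from the full examples in the diffs that follow; the `example_db` connection, model name, and question are taken from those examples, while the abbreviated system prompt and flow are illustrative only.

import os

import mindsdb_sdk
from openai import OpenAI
from mindsdb_sdk.utils.openai import (
    chat_completion_request,    # chat.completions.create wrapped with tenacity retries
    extract_sql_query,          # pulls the SQL out of a "SQLQuery: ..." style response
    query_database,             # runs SQL on a MindsDB database connection, returns a string
    pretty_print_conversation,  # colour-coded dump of the message history
)

con = mindsdb_sdk.connect()
database = con.databases.get("example_db")

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# In the full example a longer system prompt (including the table schema) instructs the
# model to answer in the "SQLQuery: ... SQLResult: ..." format that extract_sql_query parses.
messages = [
    {"role": "system", "content": "You are a SQL expert. Answer with 'SQLQuery: <query>'."},
    {"role": "user", "content": "what was the average delay on arrivals?"},
]

response = chat_completion_request(client=client, model="gpt-3.5-turbo", messages=messages)
sql = extract_sql_query(response.choices[0].message.content)  # may be None if the format was not followed
result = query_database(database, sql)

messages.append({"role": "user", "content": "SQLResult: " + result})
pretty_print_conversation(messages)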
Lines changed: 39 additions & 42 deletions
@@ -1,15 +1,18 @@
 from openai import OpenAI, OpenAIError
-from mindsdb_sdk.utils.table_schema import get_table_schemas
-from mindsdb_sdk.utils.openai import make_openai_tool
+from mindsdb_sdk.utils.openai import extract_sql_query, make_openai_tool, query_database
+
 import mindsdb_sdk
 import os
 
+from mindsdb_sdk.utils.table_schema import get_table_schemas
+
 # generate the key at https://llm.mdb.ai
 MINDSDB_API_KEY = os.environ.get("MINDSDB_API_KEY")
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 
 MODEL = "gpt-3.5-turbo"
-# text2sql prompt here (e.g. "What is the average satisfaction of passengers in the airline_passenger_satisfaction table?")
+
+# the prompt should be a question that can be answered by the database
 SYSTEM_PROMPT = """You are a SQL expert. Given an input question, first create a syntactically correct SQL query to run,
 then look at the results of the query and return the answer to the input question.
 Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the
@@ -34,55 +37,28 @@
 PROMPT = "what was the average delay on arrivals?"
 
 
-def generate_system_prompt(system_prompt, schema):
+def generate_system_prompt(system_prompt: str, schema: dict) -> dict:
     prompt = {
         "role": "system",
         "content": system_prompt.format(schema=schema)
     }
     return prompt
 
 
-def generate_user_prompt(query):
+def generate_user_prompt(query: str) -> dict:
     prompt = {
         "role": "user",
         "content": query
     }
     return prompt
 
 
-def extract_sql_query(result):
-    # Split the result into lines
-    lines = result.split('\n')
-
-    # Initialize an empty string to hold the query
-    query = ""
-
-    # Initialize a flag to indicate whether we're currently reading the query
-    reading_query = False
-
-    # Iterate over the lines
-    for line in lines:
-        # If the line starts with "SQLQuery:", start reading the query
-        if line.startswith("SQLQuery:"):
-            query = line[len("SQLQuery:"):].strip()
-            reading_query = True
-        # If the line starts with "SQLResult:", stop reading the query
-        elif line.startswith("SQLResult:"):
-            break
-        # If we're currently reading the query, append the line to the query
-        elif reading_query:
-            query += " " + line.strip()
-
-    # If no line starts with "SQLQuery:", return None
-    if query == "":
-        return None
-
-    return query
-
-
 con = mindsdb_sdk.connect()
 
-database = con.databases.get(name="example_db")
+# given database name, returns schema and database object
+# using example_db from mindsdb
+
+database = con.databases.get("example_db")
 schema = get_table_schemas(database, included_tables=["airline_passenger_satisfaction"])
 
 try:
@@ -95,22 +71,43 @@ def extract_sql_query(result):
         api_key=OPENAI_API_KEY
     )
 
+    messages = [
+        generate_system_prompt(SYSTEM_PROMPT, schema),
+        generate_user_prompt(PROMPT)
+    ]
+
     chat_completion_gpt = client_mindsdb_serve.chat.completions.create(
-        messages=[
-            generate_system_prompt(SYSTEM_PROMPT, schema),
-            generate_user_prompt(PROMPT)
-        ],
+        messages=messages,
         model=MODEL
     )
 
     response = chat_completion_gpt.choices[0].message.content
 
+    # extract the SQL query from the response
     query = extract_sql_query(response)
 
     print(f"Generated SQL query: {query}")
 
 except OpenAIError as e:
     raise OpenAIError(f"An error occurred with the MindsDB Serve API: {e}")
 
-result = database.query(query).fetch()
-print(result)
+result = query_database(database, query)
+
+# format the result to be displayed in the prompt
+query_result = "SQLResult: " + str(result)
+
+# generate the user prompt with the query result, this will be used to generate the final response
+query = generate_user_prompt(query_result)
+
+# add the query to the messages list
+messages.append(query)
+
+# generate the final response
+chat_completion_gpt = client_mindsdb_serve.chat.completions.create(
+    messages=messages,
+    model=MODEL
+)
+
+response = chat_completion_gpt.choices[0].message.content
+
+print(response)
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+from openai import OpenAI
+
+
+from mindsdb_sdk.utils.openai import (
+    make_mindsdb_tool,
+    execute_function_call,
+    chat_completion_request,
+    pretty_print_conversation)
+
+import mindsdb_sdk
+import os
+
+from mindsdb_sdk.utils.table_schema import get_table_schemas
+
+# generate the key at https://llm.mdb.ai
+MINDSDB_API_KEY = os.environ.get("MINDSDB_API_KEY")
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+
+MODEL = "gpt-3.5-turbo"
+
+
+con = mindsdb_sdk.connect()
+
+# given database name, returns schema and database object
+# using example_db from mindsdb
+
+# client_mindsdb_serve = OpenAI(
+#     api_key=MINDSDB_API_KEY,
+#     base_url="https://llm.mdb.ai"
+# )
+
+client_mindsdb_serve = OpenAI(
+    api_key=OPENAI_API_KEY
+)
+
+database = con.databases.get("example_db")
+schema = get_table_schemas(database, included_tables=["airline_passenger_satisfaction"])
+
+tools = [make_mindsdb_tool(schema)]
+
+SYSTEM_PROMPT = """You are a SQL expert. Given an input question, Answer user questions by generating SQL queries
+against the database schema provided in tools
+Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the
+LIMIT clause as per SQL standards. You can order the results to return the most informative data in the database.
+Never query for all columns from a table. You must query only the columns that are needed to answer the question.
+Wrap each column name in backticks (`) to denote them as identifiers.
+Pay attention to use only the column names you can see in the tables below.
+Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
+Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today"."""
+
+messages = [{
+    "role":"system", "content":SYSTEM_PROMPT
+}, {"role":"user", "content":"what was the average delay on arrivals?"}]
+
+chat_response = chat_completion_request(client=client_mindsdb_serve, model=MODEL, messages=messages, tools=tools, tool_choice=None)
+
+assistant_message = chat_response.choices[0].message
+
+assistant_message.content = str(assistant_message.tool_calls[0].function)
+
+messages.append({"role": assistant_message.role, "content": assistant_message.content})
+
+if assistant_message.tool_calls:
+    results = execute_function_call(message=assistant_message, database=database)
+    messages.append({
+        "role": "function", "tool_call_id": assistant_message.tool_calls[0].id,
+        "name": assistant_message.tool_calls[0].function.name,
+        "content": results
+    })
+
+pretty_print_conversation(messages)

mindsdb_sdk/utils/openai.py

Lines changed: 153 additions & 5 deletions
@@ -1,23 +1,50 @@
+import json
 
-import inspect
-import docstring_parser
+from mindsdb_sdk.databases import Database
+from tenacity import retry, wait_random_exponential, stop_after_attempt
 
 
-def make_openai_tool(function: callable):
+@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
+def chat_completion_request(client, model, messages, tools=None, tool_choice=None):
+    try:
+        response = client.chat.completions.create(
+            model=model,
+            messages=messages,
+            tools=tools,
+            tool_choice=tool_choice,
+        )
+        return response
+    except Exception as e:
+        print("Unable to generate ChatCompletion response")
+        print(f"Exception: {e}")
+        return e
+
+
+def make_openai_tool(function: callable, description: str = None) -> dict:
     """
-    Make an OpenAI tool for a function
+    Make a generic OpenAI tool for a function
 
     :param function: function to generate metadata for
+    :param description: description of the function
+
     :return: dictionary containing function metadata
     """
+    # You will need to pip install docstring-parser to use this function
+
+    import inspect
+    import docstring_parser
+
     params = inspect.signature(function).parameters
     docstring = docstring_parser.parse(function.__doc__)
 
+    # Get the first line of the docstring as the function description or use the user-provided description
+    function_description = description or docstring.short_description
+
     function_dict = {
         "type":"function",
         "function":{
             "name":function.__name__,
-            "description":docstring.short_description,
+            "description":function_description,
             "parameters":{
                 "type":"object",
                 "properties":{},
@@ -49,3 +76,124 @@ def make_openai_tool(function: callable):
 
     return function_dict
 
+
+def make_mindsdb_tool(schema: dict) -> dict:
+    """
+    Make an OpenAI tool for querying a database connection in MindsDB
+
+    :param schema: database schema
+
+    :return: dictionary containing function metadata for openai tools
+    """
+    return {
+        "type":"function",
+        "function":{
+            "name":"query_database",
+            "description":"Use this function to answer user questions. Input should be a fully formed SQL query.",
+            "parameters":{
+                "type":"object",
+                "properties":{
+                    "query":{
+                        "type":"string",
+                        "description":f"""
+                            SQL query extracting info to answer the user's question.
+                            SQL should be written using this database schema:
+                            {schema}
+                            The query should be returned in plain text, not in JSON.
+                            """,
+                    }
+                },
+                "required":["query"],
+            },
+        }
+    }
+
+
+def extract_sql_query(result: str) -> str:
+    """
+    Extract the SQL query from an openai result string
+
+    :param result: OpenAI result string
+    :return: SQL query string
+    """
+    # Split the result into lines
+    lines = result.split('\n')
+
+    # Initialize an empty string to hold the query
+    query = ""
+
+    # Initialize a flag to indicate whether we're currently reading the query
+    reading_query = False
+
+    # Iterate over the lines
+    for line in lines:
+        # If the line starts with "SQLQuery:", start reading the query
+        if line.startswith("SQLQuery:"):
+            query = line[len("SQLQuery:"):].strip()
+            reading_query = True
+        # If the line starts with "SQLResult:", stop reading the query
+        elif line.startswith("SQLResult:"):
+            break
+        # If we're currently reading the query, append the line to the query
+        elif reading_query:
+            query += " " + line.strip()
+
+    # If no line starts with "SQLQuery:", return None
+    if query == "":
+        return None
+
+    return query
+
+
+def query_database(database: Database, query: str) -> str:
+    """
+    Execute a query on a database connection
+
+    :param database: mindsdb Database object
+    :param query: SQL query string
+
+    :return: query results as a string
+    """
+    try:
+        results = str(
+            database.query(query).fetch()
+        )
+    except Exception as e:
+        results = f"query failed with error: {e}"
+    return results
+
+
+def execute_function_call(message, database: Database = None) -> str:
+    """
+    Execute a function call in a message
+
+    """
+    if message.tool_calls[0].function.name == "query_database":
+        query = json.loads(message.tool_calls[0].function.arguments)["query"]
+        results = query_database(database, query)
+    else:
+        results = f"Error: function {message.tool_calls[0].function.name} does not exist"
+    return results
+
+
+def pretty_print_conversation(messages):
+    # you will need to pip install termcolor
+    from termcolor import colored
+    role_to_color = {
+        "system":"red",
+        "user":"green",
+        "assistant":"blue",
+        "function":"magenta",
+    }
+
+    for message in messages:
+        if message["role"] == "system":
+            print(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
+        elif message["role"] == "user":
+            print(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
+        elif message["role"] == "assistant" and message.get("function_call"):
+            print(colored(f"assistant: {message['function_call']}\n", role_to_color[message["role"]]))
+        elif message["role"] == "assistant" and not message.get("function_call"):
+            print(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
+        elif message["role"] == "function":
+            print(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))

0 commit comments
