Skip to content

Commit bbfee91

Browse files
committed
Add MindsDB inference example and utility functions
This commit adds an example script showing usage of MindsDB for inference using external tools. It also introduces two utility functions in 'mindsdb_sdk' package to help in interpreting functions and obtaining database table schemas. These updates aim to make it easier for developers to leverage MindsDB in their ML projects.
1 parent dd2684a commit bbfee91

File tree

3 files changed

+207
-0
lines changed

3 files changed

+207
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from openai import OpenAI, OpenAIError
2+
from mindsdb_sdk.utils.table_schema import get_table_schemas
3+
from mindsdb_sdk.utils.openai import make_openai_tool
4+
import mindsdb_sdk
5+
import os
6+
7+
# generate the key at https://llm.mdb.ai
8+
MINDSDB_API_KEY = os.environ.get("MINDSDB_API_KEY")
9+
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
10+
11+
MODEL = "gpt-3.5-turbo"
12+
# text2sql prompt here (e.g. "What is the average satisfaction of passengers in the airline_passenger_satisfaction table?")
13+
SYSTEM_PROMPT = """You are a SQL expert. Given an input question, first create a syntactically correct SQL query to run,
14+
then look at the results of the query and return the answer to the input question.
15+
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the
16+
LIMIT clause as per SQL standards. You can order the results to return the most informative data in the database.
17+
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
18+
Wrap each column name in backticks (`) to denote them as identifiers.
19+
Pay attention to use only the column names you can see in the tables below.
20+
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
21+
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".
22+
23+
Use the following format:
24+
25+
Question: <Question here>
26+
SQLQuery: <SQL Query to run>
27+
SQLResult: <Result of the SQLQuery>
28+
Answer: <Final answer here>
29+
30+
Only use the following tables:
31+
32+
{schema}
33+
"""
34+
PROMPT = "what was the average delay on arrivals?"
35+
36+
37+
def generate_system_prompt(system_prompt, schema):
38+
prompt = {
39+
"role": "system",
40+
"content": system_prompt.format(schema=schema)
41+
}
42+
return prompt
43+
44+
45+
def generate_user_prompt(query):
46+
prompt = {
47+
"role": "user",
48+
"content": query
49+
}
50+
return prompt
51+
52+
53+
def extract_sql_query(result):
54+
# Split the result into lines
55+
lines = result.split('\n')
56+
57+
# Initialize an empty string to hold the query
58+
query = ""
59+
60+
# Initialize a flag to indicate whether we're currently reading the query
61+
reading_query = False
62+
63+
# Iterate over the lines
64+
for line in lines:
65+
# If the line starts with "SQLQuery:", start reading the query
66+
if line.startswith("SQLQuery:"):
67+
query = line[len("SQLQuery:"):].strip()
68+
reading_query = True
69+
# If the line starts with "SQLResult:", stop reading the query
70+
elif line.startswith("SQLResult:"):
71+
break
72+
# If we're currently reading the query, append the line to the query
73+
elif reading_query:
74+
query += " " + line.strip()
75+
76+
# If no line starts with "SQLQuery:", return None
77+
if query == "":
78+
return None
79+
80+
return query
81+
82+
83+
con = mindsdb_sdk.connect()
84+
85+
database = con.databases.get(name="example_db")
86+
schema = get_table_schemas(database, included_tables=["airline_passenger_satisfaction"])
87+
88+
try:
89+
# client_mindsdb_serve = OpenAI(
90+
# api_key=MINDSDB_API_KEY,
91+
# base_url="https://llm.mdb.ai"
92+
# )
93+
94+
client_mindsdb_serve = OpenAI(
95+
api_key=OPENAI_API_KEY
96+
)
97+
98+
chat_completion_gpt = client_mindsdb_serve.chat.completions.create(
99+
messages=[
100+
generate_system_prompt(SYSTEM_PROMPT, schema),
101+
generate_user_prompt(PROMPT)
102+
],
103+
model=MODEL
104+
)
105+
106+
response = chat_completion_gpt.choices[0].message.content
107+
108+
query = extract_sql_query(response)
109+
110+
print(f"Generated SQL query: {query}")
111+
112+
except OpenAIError as e:
113+
raise OpenAIError(f"An error occurred with the MindsDB Serve API: {e}")
114+
115+
result = database.query(query).fetch()
116+
print(result)

mindsdb_sdk/utils/openai.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
2+
import inspect
3+
import docstring_parser
4+
5+
6+
def make_openai_tool(function: callable):
7+
"""
8+
Make an OpenAI tool for a function
9+
10+
:param function: function to generate metadata for
11+
:return: dictionary containing function metadata
12+
"""
13+
params = inspect.signature(function).parameters
14+
docstring = docstring_parser.parse(function.__doc__)
15+
16+
function_dict = {
17+
"type":"function",
18+
"function":{
19+
"name":function.__name__,
20+
"description":docstring.short_description,
21+
"parameters":{
22+
"type":"object",
23+
"properties":{},
24+
"required":[]
25+
}
26+
}
27+
}
28+
29+
for name, param in params.items():
30+
param_description = next((p.description for p in docstring.params if p.arg_name == name), '')
31+
32+
# convert annotation type to string
33+
if param.annotation is not inspect.Parameter.empty:
34+
if inspect.isclass(param.annotation):
35+
param_type = param.annotation.__name__
36+
else:
37+
param_type = str(param.annotation)
38+
else:
39+
param_type = None
40+
41+
function_dict["function"]["parameters"]["properties"][name] = {
42+
"type":param_type,
43+
"description":param_description
44+
}
45+
46+
# Check if parameter is required
47+
if param.default == inspect.Parameter.empty:
48+
function_dict["function"]["parameters"]["required"].append(name)
49+
50+
return function_dict
51+

mindsdb_sdk/utils/table_schema.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import List
2+
from mindsdb_sdk.databases import Databases
3+
4+
5+
def get_dataframe_schema(df):
6+
# Get the dtypes Series
7+
try:
8+
df = df.convert_dtypes()
9+
except Exception as e:
10+
raise f"Error converting dtypes: {e}"
11+
12+
dtypes = df.dtypes
13+
14+
# Convert the dtypes Series into a list of dictionaries
15+
schema = [{"name": column, "type": dtype.name} for column, dtype in dtypes.items()]
16+
17+
return schema
18+
19+
20+
def get_table_schemas(database: Databases, included_tables: List[str] = None):
21+
"""
22+
Get table schemas from a database
23+
24+
:param database: database object
25+
:param included_tables: list of table names to get schemas for
26+
:return: dictionary containing table schemas
27+
"""
28+
29+
tables = [table.name for table in database.tables.list()]
30+
31+
if included_tables:
32+
tables = [table for table in tables if table in included_tables]
33+
34+
table_schemas = {}
35+
for table in tables:
36+
table_df = database.get_table(table).fetch()
37+
# Convert schema to list of dictionaries
38+
table_schemas[table] = get_dataframe_schema(table_df)
39+
40+
return table_schemas

0 commit comments

Comments
 (0)