Skip to content

Commit 1b5f51d

Browse files
committed
Refactor the mindsdb inference script
The commit updates the `using_mindsdb_inference_with_text2sql_prompt.py` script. Important changes include removing redundant lines and changing the way query results are queried. Also, the function is updated to obtain chat results using a 'chat_completion
1 parent c456587 commit 1b5f51d

File tree

1 file changed

+28
-48
lines changed

1 file changed

+28
-48
lines changed
Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from openai import OpenAI, OpenAIError
2-
from mindsdb_sdk.utils.openai import extract_sql_query, make_openai_tool, query_database
1+
from openai import OpenAI
2+
from mindsdb_sdk.utils.openai import extract_sql_query, query_database, chat_completion_request
33

44
import mindsdb_sdk
55
import os
@@ -13,14 +13,10 @@
1313
MODEL = "gpt-3.5-turbo"
1414

1515
# the prompt should be a question that can be answered by the database
16-
SYSTEM_PROMPT = """You are a SQL expert. Given an input question, first create a syntactically correct SQL query to run,
17-
then look at the results of the query and return the answer to the input question.
18-
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the
19-
LIMIT clause as per SQL standards. You can order the results to return the most informative data in the database.
20-
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
21-
Wrap each column name in backticks (`) to denote them as identifiers.
22-
Pay attention to use only the column names you can see in the tables below.
23-
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
16+
SYSTEM_PROMPT = """You are a SQL expert. Given an input question, first create a syntactically correct SQL query to run, then look at the results of the query and return the answer to the input question.
17+
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQL standards. You can order the results to return the most informative data in the database.
18+
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as identifiers.
19+
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
2420
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".
2521
2622
Use the following format:
@@ -39,16 +35,16 @@
3935

4036
def generate_system_prompt(system_prompt: str, schema: dict) -> dict:
4137
prompt = {
42-
"role": "system",
43-
"content": system_prompt.format(schema=schema)
38+
"role":"system",
39+
"content":system_prompt.format(schema=schema)
4440
}
4541
return prompt
4642

4743

4844
def generate_user_prompt(query: str) -> dict:
4945
prompt = {
50-
"role": "user",
51-
"content": query
46+
"role":"user",
47+
"content":query
5248
}
5349
return prompt
5450

@@ -61,53 +57,37 @@ def generate_user_prompt(query: str) -> dict:
6157
database = con.databases.get("example_db")
6258
schema = get_table_schemas(database, included_tables=["airline_passenger_satisfaction"])
6359

64-
try:
65-
# client_mindsdb_serve = OpenAI(
66-
# api_key=MINDSDB_API_KEY,
67-
# base_url="https://llm.mdb.ai"
68-
# )
60+
# client_mindsdb_serve = OpenAI(
61+
# api_key=MINDSDB_API_KEY,
62+
# base_url="https://llm.mdb.ai"
63+
# )
6964

70-
client_mindsdb_serve = OpenAI(
71-
api_key=OPENAI_API_KEY
72-
)
73-
74-
messages = [
75-
generate_system_prompt(SYSTEM_PROMPT, schema),
76-
generate_user_prompt(PROMPT)
77-
]
78-
79-
chat_completion_gpt = client_mindsdb_serve.chat.completions.create(
80-
messages=messages,
81-
model=MODEL
82-
)
83-
84-
response = chat_completion_gpt.choices[0].message.content
65+
client_mindsdb_serve = OpenAI(
66+
api_key=OPENAI_API_KEY
67+
)
8568

86-
# extract the SQL query from the response
87-
query = extract_sql_query(response)
69+
messages = [
70+
generate_system_prompt(SYSTEM_PROMPT, schema),
71+
generate_user_prompt(PROMPT)
72+
]
8873

89-
print(f"Generated SQL query: {query}")
74+
chat_response = chat_completion_request(client=client_mindsdb_serve, model=MODEL, messages=messages, tools=None,
75+
tool_choice=None)
9076

91-
except OpenAIError as e:
92-
raise OpenAIError(f"An error occurred with the MindsDB Serve API: {e}")
77+
# extract the SQL query from the response
78+
query = extract_sql_query(chat_response.choices[0].message.content)
9379

9480
result = query_database(database, query)
9581

96-
# format the result to be displayed in the prompt
97-
query_result = "SQLResult: " + str(result)
98-
9982
# generate the user prompt with the query result, this will be used to generate the final response
100-
query = generate_user_prompt(query_result)
83+
query = generate_user_prompt(f"Given this SQLResult: {str(result)} provide Answer: ")
10184

10285
# add the query to the messages list
10386
messages.append(query)
10487

10588
# generate the final response
106-
chat_completion_gpt = client_mindsdb_serve.chat.completions.create(
107-
messages=messages,
108-
model=MODEL
109-
)
110-
89+
chat_completion_gpt = chat_completion_request(client=client_mindsdb_serve, model=MODEL, messages=messages, tools=None,
90+
tool_choice=None)
11191
response = chat_completion_gpt.choices[0].message.content
11292

11393
print(response)

0 commit comments

Comments
 (0)