|
| 1 | +from openai import OpenAI |
| 2 | +from mindsdb_sdk.utils.openai import extract_sql_query, query_database, chat_completion_request, \ |
| 3 | + pretty_print_conversation |
| 4 | + |
| 5 | +import mindsdb_sdk |
| 6 | +import os |
| 7 | + |
| 8 | +from mindsdb_sdk.utils.table_schema import get_table_schemas |
| 9 | + |
| 10 | +# generate the key at https://llm.mdb.ai |
| 11 | +MINDSDB_API_KEY = os.environ.get("MINDSDB_API_KEY") |
| 12 | +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") |
| 13 | + |
| 14 | +MODEL = "gpt-3.5-turbo" |
| 15 | + |
| 16 | +# the prompt should be a question that can be answered by the database |
| 17 | +SYSTEM_PROMPT = """You are a SQL expert. Given an input question, first create a syntactically correct SQL query to run, then look at the results of the query and return the answer to the input question. |
| 18 | +Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQL standards. You can order the results to return the most informative data in the database. |
| 19 | +Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as identifiers. |
| 20 | +Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. |
| 21 | +Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today". |
| 22 | +
|
| 23 | +Use the following format: |
| 24 | +
|
| 25 | +Question: <Question here> |
| 26 | +SQLQuery: <SQL Query to run> |
| 27 | +SQLResult: <Result of the SQLQuery> |
| 28 | +Answer: <Final answer here> |
| 29 | +
|
| 30 | +Only use the following tables: |
| 31 | +
|
| 32 | +{schema} |
| 33 | +""" |
| 34 | +PROMPT = "what was the average delay on arrivals?" |
| 35 | + |
| 36 | + |
| 37 | +def generate_system_prompt(system_prompt: str, schema: dict) -> dict: |
| 38 | + prompt = { |
| 39 | + "role":"system", |
| 40 | + "content":system_prompt.format(schema=schema) |
| 41 | + } |
| 42 | + return prompt |
| 43 | + |
| 44 | + |
| 45 | +def generate_user_prompt(query: str) -> dict: |
| 46 | + prompt = { |
| 47 | + "role":"user", |
| 48 | + "content":query |
| 49 | + } |
| 50 | + return prompt |
| 51 | + |
| 52 | + |
| 53 | +con = mindsdb_sdk.connect() |
| 54 | + |
| 55 | +# given database name, returns schema and database object |
| 56 | +# using example_db from mindsdb |
| 57 | + |
| 58 | +database = con.databases.get("example_db") |
| 59 | +schema = get_table_schemas(database, included_tables=["airline_passenger_satisfaction"]) |
| 60 | + |
| 61 | +# client_mindsdb_serve = OpenAI( |
| 62 | +# api_key=MINDSDB_API_KEY, |
| 63 | +# base_url="https://llm.mdb.ai" |
| 64 | +# ) |
| 65 | + |
| 66 | +client_mindsdb_serve = OpenAI( |
| 67 | + api_key=OPENAI_API_KEY |
| 68 | +) |
| 69 | + |
| 70 | +messages = [ |
| 71 | + generate_system_prompt(SYSTEM_PROMPT, schema), |
| 72 | + generate_user_prompt(PROMPT) |
| 73 | +] |
| 74 | + |
| 75 | +chat_response = chat_completion_request(client=client_mindsdb_serve, model=MODEL, messages=messages, tools=None, |
| 76 | + tool_choice=None) |
| 77 | + |
| 78 | +# extract the SQL query from the response |
| 79 | +query = extract_sql_query(chat_response.choices[0].message.content) |
| 80 | + |
| 81 | +result = query_database(database, query) |
| 82 | + |
| 83 | +# generate the user prompt with the query result, this will be used to generate the final response |
| 84 | +query = generate_user_prompt(f"Given this SQLResult: {str(result)} provide Answer: ") |
| 85 | + |
| 86 | +# add the query to the messages list |
| 87 | +messages.append(query) |
| 88 | + |
| 89 | +# generate the final response |
| 90 | +chat_completion_gpt = chat_completion_request(client=client_mindsdb_serve, model=MODEL, messages=messages, tools=None, |
| 91 | + tool_choice=None) |
| 92 | +response = chat_completion_gpt.choices[0].message.content |
| 93 | + |
| 94 | +pretty_print_conversation(messages) |
| 95 | + |
0 commit comments