diff --git a/.env.copy b/.env.copy index e83fac3..8579c4f 100644 --- a/.env.copy +++ b/.env.copy @@ -6,3 +6,4 @@ DB_USER=postgres DB_PASSWORD=postgres OPENAI_API_KEY=your_actual_openai_api_key_here +GEMINI_API_KEY=your_actual_gemini_api_key_here diff --git a/.gitignore b/.gitignore index 4f64a5c..81aeba9 100644 --- a/.gitignore +++ b/.gitignore @@ -275,3 +275,5 @@ Thumbs.db # config/local.py # uploads/ # media/ + +.venv311 \ No newline at end of file diff --git a/app.py b/app.py index b197072..cf5a1d6 100644 --- a/app.py +++ b/app.py @@ -233,9 +233,9 @@ id="query-strategy", options=[ {"label": "Schema-Based Querying", "value": "schema"}, + {"label": "Basic Text-to-SQL", "value": "basic"}, {"label": "RAG (Retrieval-Augmented Generation)", "value": "rag"}, {"label": "Visualize", "value": "visualize"}, - {"label": "RAG (Retrieval-Augmented Generation)", "value": "rag"}, {"label": "Multi-Table Join", "value": "multitablejoin"} ], value="schema", @@ -413,7 +413,11 @@ def update_chat(n_clicks, n_submit, input_value, chat_history, settings, connect try: # Create query engine - engine_config = {"openai_api_key": Config.OPENAI_API_KEY, "db_uri": SQLITE_DB_PATH} + engine_config = { + "OPENAI_API_KEY": Config.OPENAI_API_KEY, + "GEMINI_API_KEY": Config.GEMINI_API_KEY, + "db_uri": SQLITE_DB_PATH + } query_engine = query_engine_factory.create_query_engine(strategy, engine_config) # Create security guardrail if enabled diff --git a/config.py b/config.py index 79b4db0..077a5e0 100644 --- a/config.py +++ b/config.py @@ -17,6 +17,7 @@ class Config: # LLM Configuration OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') + GEMINI_API_KEY = os.getenv('GEMINI_API_KEY') ANTHROPIC_API_KEY = os.getenv('ANTHROPIC_API_KEY') # App Configuration diff --git a/database/query_engine.py b/database/query_engine.py index 0bf693e..a6848a2 100644 --- a/database/query_engine.py +++ b/database/query_engine.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, Any, Tuple, List import openai +import google.generativeai as genai # For Gemini API from openai import OpenAI import logging from datetime import datetime @@ -496,6 +497,153 @@ def execute_query(self, sql_query: str) -> Tuple[bool, Any]: +class BasicTextToSQLEngine(QueryEngine): + """Basic text-to-SQL using manual prompt construction with schema and few-shot examples""" + + def __init__(self, gemini_api_key: str): + self.gemini_api_key = gemini_api_key + genai.configure(api_key=gemini_api_key) # Configure Gemini API + logger.info("Initialized BasicTextToSQLEngine with Gemini API") + + def get_name(self) -> str: + return "Basic Text-to-SQL" + + def generate_query(self, user_query: str, context: Dict[str, Any]) -> Tuple[bool, str]: + """Generate SQL query using basic text-to-SQL with manual prompt construction""" + logger.info(f"Starting basic text-to-SQL generation for: '{user_query}'") + + try: + # Get database schema information + if not db_connection.is_connected(): + return False, "Not connected to database" + + # Get tables and their schemas + tables_success, tables_result = db_connection.get_tables() + if not tables_success: + return False, f"Failed to get tables: {tables_result}" + + if tables_result.empty: + return False, "No tables found in the database" + + # Build the manual prompt with schema and examples + prompt = self._build_manual_prompt(user_query, tables_result, context) + + # Call Gemini API with the constructed prompt (free tier compatible) + model = genai.GenerativeModel('gemini-1.5-flash') # Free tier model + response = model.generate_content( + prompt, + generation_config=genai.types.GenerationConfig( + max_output_tokens=300, # Reduced for free tier + temperature=0.1, + candidate_count=1, # Free tier supports only 1 candidate + ) + ) + sql_query = response.text.strip() + + # Fallback to OpenAI if Gemini fails (commented out) + # response = openai.ChatCompletion.create( + # model="gpt-3.5-turbo", + # messages=[ + # {"role": "user", "content": prompt} + # ], + # max_tokens=500, + # temperature=0.1 + # ) + # sql_query = response.choices[0].message.content.strip() + + # Clean up the response + if sql_query.startswith('```sql'): + sql_query = sql_query[6:] + if sql_query.endswith('```'): + sql_query = sql_query[:-3] + sql_query = sql_query.strip() + + logger.info(f"Generated SQL: {sql_query}") + return True, sql_query + + except Exception as e: + logger.error(f"Error in basic text-to-SQL generation: {str(e)}") + return False, f"Error generating SQL: {str(e)}" + + def execute_query(self, sql_query: str) -> Tuple[bool, Any]: + """Execute the generated SQL query""" + try: + return db_connection.execute_query(sql_query) + except Exception as e: + logger.error(f"Error executing query: {str(e)}") + return False, f"Error executing query: {str(e)}" + + def _build_manual_prompt(self, user_query: str, tables_df, context: Dict[str, Any]) -> str: + """Build a comprehensive manual prompt with schema and few-shot examples""" + db_type = context.get('db_type', 'postgresql') + + # Build full database schema + schema_info = self._build_full_schema(tables_df) + + # Create the manual prompt + prompt = f"""You are a MySQL expert. Your role is to generate a valid SQL query based on the user's natural language question. + +DATABASE SCHEMA: +{schema_info} + +FEW-SHOT EXAMPLES: + +Example 1: +Question: "Which customers in California spent the most last quarter?" +SQL: SELECT customer_name, SUM(amount) as total_spent FROM orders o JOIN customers c ON o.customer_id = c.id WHERE c.state = 'California' AND o.order_date >= DATE_SUB(NOW(), INTERVAL 3 MONTH) GROUP BY c.id, customer_name ORDER BY total_spent DESC LIMIT 10; + +Example 2: +Question: "Show me all books published after 2020" +SQL: SELECT title, author, publication_year FROM books WHERE publication_year > 2020 ORDER BY publication_year DESC; + +Example 3: +Question: "How many users have borrowed books in the last month?" +SQL: SELECT COUNT(DISTINCT user_id) as active_borrowers FROM book_loans WHERE loan_date >= DATE_SUB(NOW(), INTERVAL 1 MONTH); + +INSTRUCTIONS: +1. Generate a valid {db_type} SQL query +2. Use ONLY the tables and columns shown in the schema above +3. Include appropriate WHERE clauses, JOINs, and ORDER BY as needed +4. Add LIMIT clause if the query might return many rows +5. Use proper SQL syntax and formatting +6. Return only the SQL query, no explanations + +USER QUESTION: {user_query} + +SQL QUERY:""" + + return prompt + + def _build_full_schema(self, tables_df) -> str: + """Build complete database schema information""" + schema_parts = [] + + for _, table_row in tables_df.iterrows(): + table_name = table_row['table_name'] + schema_parts.append(f"\nTable: {table_name}") + + # Get column information for this table + schema_success, schema_result = db_connection.get_table_schema(table_name) + if schema_success and not schema_result.empty: + schema_parts.append("Columns:") + for _, col_row in schema_result.iterrows(): + col_name = col_row['column_name'] + col_type = col_row['data_type'] + is_nullable = col_row.get('is_nullable', 'YES') + col_default = col_row.get('column_default', '') + + col_info = f" - {col_name} ({col_type})" + if is_nullable == 'NO': + col_info += " NOT NULL" + if col_default: + col_info += f" DEFAULT {col_default}" + + schema_parts.append(col_info) + else: + schema_parts.append(" (Schema information not available)") + + return "\n".join(schema_parts) + class RAGQueryEngine(QueryEngine): """RAG-based query generation (placeholder)""" @@ -590,6 +738,12 @@ def create_query_engine(engine_type: str, config: Dict[str, Any]) -> QueryEngine if not api_key: raise ValueError("OpenAI API key required for schema-based querying") return SchemaBasedQueryEngine(api_key) + elif engine_type == "basic": + # Using Gemini API for Basic Text-to-SQL + gemini_key = config.get('GEMINI_API_KEY') + if not gemini_key: + raise ValueError("Gemini API key required for basic text-to-SQL") + return BasicTextToSQLEngine(gemini_key) elif engine_type == "rag": return RAGQueryEngine() elif engine_type == "multitablejoin": diff --git a/requirements.txt b/requirements.txt index 6962680..80f02ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,13 +3,14 @@ dash-bootstrap-components==1.5.0 dash-extensions==1.0.4 plotly==5.17.0 pandas==2.1.4 -numpy==1.25.2 +numpy>=1.26.4,<3 python-dotenv==1.0.0 requests==2.31.0 sqlalchemy==2.0.23 psycopg2-binary==2.9.9 pymysql==1.1.0 cryptography==41.0.7 +google-generativeai==0.3.2 openai==1.102.0 flask==2.3.3 gunicorn==21.2.0