33import dotenv
44import os
55import sys
6- import codecs
76import json
8- import uuid
97import urllib .parse
108import anthropic
119from mcp import ClientSession
1210from mcp .client .sse import sse_client
1311from tabulate import tabulate
1412
15- # This function is no longer needed as we're using the connect tool
16- # Kept for reference but not used
17- def postgres_connection_to_uuid (connection_string , namespace = uuid .NAMESPACE_URL ):
18- """
19- Convert a PostgreSQL connection string into a deterministic Version 5 UUID.
20- Includes both connection credentials (netloc) and database name (path).
21- """
22- # Make sure connection_string has proper protocol prefix
23- if not connection_string .startswith ("postgresql://" ):
24- connection_string = f"postgresql://{ connection_string } "
25-
26- # Parse the connection string
27- parsed = urllib .parse .urlparse (connection_string )
28-
29- # Extract the netloc (user:password@host:port) and path (database name)
30- connection_id_string = parsed .netloc + parsed .path
31-
32- # Create a Version 5 UUID (SHA-1 based)
33- result_uuid = uuid .uuid5 (namespace , connection_id_string )
34-
35- return str (result_uuid )
36-
async def fetch_database_hierarchy(session, conn_id):
    """Fetch the complete database structure from the MCP server.

    Reads the ``pgmcp://{conn_id}/`` resource and decodes the text of the
    first returned content item as JSON.

    Args:
        session: Object exposing an awaitable ``read_resource(uri)`` method.
        conn_id: Connection identifier interpolated into the resource URI.

    Returns:
        The decoded JSON payload, or ``None`` when the response carries no
        usable text content. Any exception raised while reading or decoding
        is printed and converted to ``None`` rather than propagated.
    """
    try:
        # Get the complete database information in a single request.
        db_resource = f"pgmcp://{conn_id}/"
        db_response = await session.read_resource(db_resource)

        # The payload may be exposed as ``content`` or ``contents``;
        # use whichever attribute is present and non-empty.
        items = (
            getattr(db_response, 'content', None)
            or getattr(db_response, 'contents', None)
        )
        if not items:
            return None

        # Only the first content item is consulted.
        first_item = items[0]
        if not hasattr(first_item, 'text'):
            return None

        return json.loads(first_item.text)
    except Exception as e:
        # Best-effort fetch: report the failure and let the caller handle None.
        print(f"Error fetching database hierarchy: {e}")
        return None
11035
def _column_constraints(column, table):
    """Return the constraint labels to display next to one column."""
    labels = []

    if not column['nullable']:
        labels.append('NOT NULL')

    declared = column.get('constraints') or []
    labels.extend(kind for kind in ('PRIMARY KEY', 'UNIQUE') if kind in declared)

    # A column may take part in several foreign keys; emit one label per key.
    column_name = column['name']
    for fk in table.get('foreign_keys') or []:
        fk_columns = fk.get('columns') or []
        if column_name not in fk_columns:
            continue

        ref_schema = fk.get('referenced_schema', '')
        ref_table = fk.get('referenced_table', '')
        ref_columns = fk.get('referenced_columns') or []

        # Pair the column with the referenced column at the same position;
        # fall back to an empty name when the referenced list is shorter.
        position = fk_columns.index(column_name)
        ref_col = ref_columns[position] if position < len(ref_columns) else ''

        labels.append(f"FK → {ref_schema}.{ref_table}({ref_col})")

    return labels


def format_database_hierarchy(db_structure):
    """Format the database structure as a hierarchical, tree-style text block.

    Args:
        db_structure: Mapping with a ``'schemas'`` list. Each schema has a
            ``'name'`` and a ``'tables'`` list; each table has a ``'name'``,
            a ``'columns'`` list and optional ``'description'``,
            ``'row_count'`` and ``'foreign_keys'``; each column has
            ``'name'``, ``'type'``, ``'nullable'`` and optional
            ``'constraints'``.

    Returns:
        The formatted hierarchy, or ``"No database structure available."``
        when ``db_structure`` is empty or has no ``'schemas'`` key.

    Raises:
        KeyError: If a schema, table or column lacks a required key
            (``'name'``, ``'type'``, ``'nullable'``).
    """
    if not db_structure or 'schemas' not in db_structure:
        return "No database structure available."

    # Collect lines and join once instead of growing a string with ``+=``.
    lines = ["DATABASE HIERARCHY:", ""]

    schemas = db_structure['schemas'] or []
    last_schema_index = len(schemas) - 1

    for schema_index, schema in enumerate(schemas):
        lines.append(f"SCHEMA: {schema['name']}")

        tables = schema.get('tables') or []
        last_table_index = len(tables) - 1

        for table_index, table in enumerate(tables):
            is_last_table = table_index == last_table_index

            # NOTE(review): tree indent widths are assumed to be four columns
            # ('│   ' / '    ') -- confirm against the intended console layout.
            branch = '└── ' if is_last_table else '├── '
            indent = '    ' if is_last_table else '│   '

            # No default here: an absent row count is omitted from the output
            # rather than being reported as "(0 rows)".
            row_count = table.get('row_count')
            row_count_text = f" ({row_count} rows)" if row_count is not None else ""

            lines.append(f"{branch}TABLE: {table['name']}{row_count_text}")

            columns = table.get('columns') or []
            last_column_index = len(columns) - 1

            for column_index, column in enumerate(columns):
                constraints = _column_constraints(column, table)
                constraints_text = f", {', '.join(constraints)}" if constraints else ""
                leaf = '└── ' if column_index == last_column_index else '├── '

                lines.append(
                    f"{indent}{leaf}{column['name']}: {column['type']}{constraints_text}"
                )

            # The table description is listed after the table's columns.
            table_desc = table.get('description', '')
            if table_desc:
                lines.append(f"{indent}Description: {table_desc}")

            # Vertical connector between tables (not after the last one).
            if not is_last_table:
                lines.append("│")

        # Blank line between schemas; decided by position, not by dict
        # equality, so identical schema entries are still separated.
        if schema_index != last_schema_index:
            lines.append("")

    return "\n".join(lines) + "\n"
120+
def clean_sql_query(sql_query: str) -> str:
    """Undo backslash-escaping of quotes and backslashes in a SQL query.

    The query is scanned left to right, consuming escape sequences as
    two-character units:

    * ``\\"`` becomes ``"``
    * ``\\\\`` becomes a single backslash
    * any other backslash sequence (for example ``\\n``) is kept verbatim

    A lone backslash at the very end of the input has nothing to escape and
    is kept as-is. Leading and trailing whitespace is stripped from the
    result.

    Args:
        sql_query (str): The SQL query to clean.

    Returns:
        str: Cleaned SQL query.
    """
    # Accumulate pieces and join once; repeated ``+=`` on a string can be
    # quadratic in the length of the query.
    pieces = []
    length = len(sql_query)
    i = 0

    while i < length:
        char = sql_query[i]
        if char == '\\' and i + 1 < length:
            # Escape sequence: inspect the escaped character as a pair so an
            # escaped backslash cannot be misread as starting a new escape.
            escaped = sql_query[i + 1]
            if escaped in ('"', '\\'):
                pieces.append(escaped)
            else:
                pieces.append(char + escaped)
            i += 2
        else:
            # Regular character, or a trailing backslash with nothing after it.
            pieces.append(char)
            i += 1

    # Remove any extraneous surrounding whitespace or newlines.
    return "".join(pieces).strip()
145159
146160async def generate_sql_with_anthropic (user_query , schema_text , anthropic_api_key ):
147161 """Generate SQL using Claude with response template prefilling."""
@@ -300,12 +314,21 @@ async def main():
300314 print (f"Error registering connection: { e } " )
301315 sys .exit (1 )
302316
303- # Fetch schema information
304- print ("Fetching database schema information..." )
305- schema_info = await fetch_schema_info (session , conn_id )
306- schema_text = format_schema_for_prompt (schema_info )
317+ # Fetch database hierarchy
318+ print ("Fetching database hierarchy information..." )
319+ db_hierarchy = await fetch_database_hierarchy (session , conn_id )
320+
321+ # Display the database hierarchy
322+ print ("\n Database Structure:" )
323+ print ("==================\n " )
324+ hierarchy_text = format_database_hierarchy (db_hierarchy )
325+ print (hierarchy_text )
326+ print ("\n ==================\n " )
327+
328+ # Use this hierarchy for Claude's prompt
329+ schema_text = hierarchy_text
307330
308- print (f"Retrieved information for { len ( schema_info ) } tables ." )
331+ print (f"Retrieved database structure information ." )
309332
310333 # Generate SQL using Claude with schema context
311334 print ("Generating SQL query with Claude..." )
@@ -321,6 +344,7 @@ async def main():
321344 print (f"------------" )
322345 print (explanation )
323346
347+ # Original query (as generated by Claude)
324348 print (f"\n Generated SQL query:" )
325349 print (f"------------------" )
326350 print (sql_query )
@@ -330,8 +354,16 @@ async def main():
330354 print ("No SQL query was generated. Exiting." )
331355 sys .exit (1 )
332356
357+ # Clean the SQL query before execution
358+ sql_query = clean_sql_query (sql_query )
359+
360+ # Show the cleaned query
361+ print (f"Cleaned SQL query:" )
362+ print (f"------------------" )
363+ print (sql_query )
364+ print (f"------------------\n " )
365+
333366 # Execute the generated SQL query
334- sql_query = codecs .decode (sql_query , 'unicode_escape' )
335367 print ("Executing SQL query..." )
336368 try :
337369 result = await session .call_tool (
0 commit comments