Skip to content

Commit 30170ca

Browse files
committed
Full Schema and Basic Prompt
1 parent 62935d2 commit 30170ca

File tree

9 files changed

+851
-128
lines changed

9 files changed

+851
-128
lines changed

.env.example

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
PG_MCP_URL=http://localhost:8000/sse
22
DATABASE_URL=postgresql://user:password@hostname:port/databasename
3-
ANTHROPIC_API_KEY=your-anthropic-api-key
3+
ANTHROPIC_API_KEY=your-anthropic-api-key
4+
GEMINI_API_KEY=your-gemini-api-key
File renamed without changes.

client/claude_cli.py renamed to example-clients/claude_cli.py

Lines changed: 157 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -3,145 +3,159 @@
33
import dotenv
44
import os
55
import sys
6-
import codecs
76
import json
8-
import uuid
97
import urllib.parse
108
import anthropic
119
from mcp import ClientSession
1210
from mcp.client.sse import sse_client
1311
from tabulate import tabulate
1412

15-
# This function is no longer needed as we're using the connect tool
16-
# Kept for reference but not used
17-
def postgres_connection_to_uuid(connection_string, namespace=uuid.NAMESPACE_URL):
    """
    Convert a PostgreSQL connection string into a deterministic Version 5 UUID.

    The UUID is derived from the connection credentials (netloc) and the
    database name (path) only; the scheme, query string and fragment do not
    contribute to it.

    Args:
        connection_string (str): Connection URI, with or without a leading
            ``postgresql://`` or ``postgres://`` scheme.
        namespace (uuid.UUID): Namespace used for the SHA-1 based UUID.

    Returns:
        str: The UUID in its canonical string form.
    """
    # Both scheme spellings are accepted. Only add a prefix when neither is
    # present; otherwise "postgres://..." would be double-prefixed and the
    # parser would report a bogus netloc.
    if not connection_string.startswith(("postgresql://", "postgres://")):
        connection_string = f"postgresql://{connection_string}"

    parsed = urllib.parse.urlparse(connection_string)

    # netloc is "user:password@host:port", path is "/databasename"
    connection_id_string = parsed.netloc + parsed.path

    return str(uuid.uuid5(namespace, connection_id_string))
36-
37-
async def fetch_schema_info(session, conn_id):
38-
"""Fetch database schema information from the MCP server."""
39-
schema_info = []
40-
41-
# First get all schemas
13+
async def fetch_database_hierarchy(session, conn_id):
    """Fetch the complete database structure from the MCP server.

    Reads the ``pgmcp://{conn_id}/`` resource and JSON-decodes the text of
    the first content item that carries any.

    Args:
        session: Object exposing an awaitable ``read_resource(uri)``
            (presumably an ``mcp.ClientSession`` -- confirm against caller).
        conn_id: Connection identifier interpolated into the resource URI.

    Returns:
        The decoded JSON payload, or ``None`` when the response holds no
        text content or when any error occurs (the error is printed).
    """
    try:
        # Get the complete database information
        db_response = await session.read_resource(f"pgmcp://{conn_id}/")

        # The payload may be exposed as `content` or `contents`; take
        # whichever is present and non-empty.
        items = (
            getattr(db_response, 'content', None)
            or getattr(db_response, 'contents', None)
        )

        # Do not assume the text lives in the first item: skip items that
        # carry no text instead of giving up on the whole response.
        for item in items or ():
            text = getattr(item, 'text', None)
            if text is not None:
                return json.loads(text)

        return None
    except Exception as e:
        # Best-effort boundary: the caller treats None as "no structure".
        print(f"Error fetching database hierarchy: {e}")
        return None
11035

111-
def format_schema_for_prompt(schema_info):
112-
"""Format schema information as a string for the prompt."""
113-
if not schema_info:
114-
return "No schema information available."
36+
def _column_constraints(column, foreign_keys):
    """Return the constraint labels to display for one column."""
    labels = []

    # A column without a 'nullable' entry is treated as nullable.
    if not column.get('nullable', True):
        labels.append('NOT NULL')

    declared = column.get('constraints') or []
    labels.extend(c for c in ('PRIMARY KEY', 'UNIQUE') if c in declared)

    column_name = column['name']
    for fk in foreign_keys:
        fk_columns = fk.get('columns') or []
        if column_name not in fk_columns:
            continue
        # Pair the column with the referenced column at the same position;
        # leave it blank when the referenced list is shorter.
        ref_columns = fk.get('referenced_columns') or []
        position = fk_columns.index(column_name)
        ref_col = ref_columns[position] if position < len(ref_columns) else ''
        ref_schema = fk.get('referenced_schema', '')
        ref_table = fk.get('referenced_table', '')
        labels.append(f"FK → {ref_schema}.{ref_table}({ref_col})")

    return labels


def _table_lines(table, is_last_table):
    """Return the output lines (without newlines) describing one table."""
    # Only show a row count when one was actually reported.
    row_count = table.get('row_count')
    row_count_text = f" ({row_count} rows)" if row_count is not None else ""

    table_prefix = '└── ' if is_last_table else '├── '
    lines = [f"{table_prefix}TABLE: {table['name']}{row_count_text}"]

    # Lines nested under a non-final table keep the vertical rule going.
    # NOTE(review): prefix spacing is reproduced as it appears in the
    # reviewed text; confirm the intended indent width against the repo.
    indent = ' ' if is_last_table else '│ '

    columns = table.get('columns') or []
    foreign_keys = table.get('foreign_keys') or []
    last_column = len(columns) - 1
    for position, column in enumerate(columns):
        constraints = _column_constraints(column, foreign_keys)
        constraints_text = f", {', '.join(constraints)}" if constraints else ""
        connector = '└── ' if position == last_column else '├── '
        lines.append(
            f"{indent}{connector}{column['name']}: {column['type']}{constraints_text}"
        )

    table_desc = table.get('description', '')
    if table_desc:
        lines.append(f"{indent}Description: {table_desc}")

    # Vertical spacing between tables (except after the last table)
    if not is_last_table:
        lines.append("│")

    return lines


def format_database_hierarchy(db_structure):
    """Format the database structure as hierarchical, tree-style text.

    Args:
        db_structure (dict | None): Mapping with a ``'schemas'`` list. Each
            schema needs ``'name'`` and may have ``'tables'``. Each table
            needs ``'name'`` and may have ``'description'``, ``'row_count'``,
            ``'columns'`` and ``'foreign_keys'``. Each column needs
            ``'name'`` and ``'type'`` and may have ``'nullable'`` and
            ``'constraints'``.

    Returns:
        str: The formatted hierarchy, every line newline-terminated, or
        ``"No database structure available."`` when ``db_structure`` is
        empty or has no ``'schemas'`` key.
    """
    if not db_structure or 'schemas' not in db_structure:
        return "No database structure available."

    lines = ["DATABASE HIERARCHY:", ""]

    schemas = db_structure['schemas']
    last_schema = len(schemas) - 1
    for schema_pos, schema in enumerate(schemas):
        lines.append(f"SCHEMA: {schema['name']}")

        tables = schema.get('tables') or []
        last_table = len(tables) - 1
        for table_pos, table in enumerate(tables):
            lines.extend(_table_lines(table, table_pos == last_table))

        # Blank line between schemas. Compare positions, not dict equality,
        # so duplicate or equal schema entries are still separated.
        if schema_pos != last_schema:
            lines.append("")

    return "\n".join(lines) + "\n"
120+
121+
def clean_sql_query(sql_query):
    """
    Clean a SQL query by unescaping quotes and backslashes.

    The query is scanned left to right. A backslash followed by another
    character is consumed as a pair: ``\\"`` becomes ``"``, a doubled
    backslash becomes a single backslash, and any other pair is kept
    unchanged. A lone trailing backslash is kept as-is. Leading and trailing
    whitespace is stripped from the result.

    Args:
        sql_query (str): The SQL query to clean

    Returns:
        str: Cleaned SQL query
    """
    # Collect pieces and join once rather than growing a string per character.
    pieces = []
    pos = 0
    end = len(sql_query)

    while pos < end:
        char = sql_query[pos]
        if char == '\\' and pos + 1 < end:
            # Escape sequence: always consume the backslash and what follows,
            # so the second half of a pair never starts a new escape.
            following = sql_query[pos + 1]
            if following in ('"', '\\'):
                pieces.append(following)
            else:
                # Some other escape sequence, keep it untouched
                pieces.append(char + following)
            pos += 2
        else:
            # Regular character (including a lone trailing backslash)
            pieces.append(char)
            pos += 1

    # Remove any extraneous whitespace or newlines
    return "".join(pieces).strip()
145159

146160
async def generate_sql_with_anthropic(user_query, schema_text, anthropic_api_key):
147161
"""Generate SQL using Claude with response template prefilling."""
@@ -300,12 +314,21 @@ async def main():
300314
print(f"Error registering connection: {e}")
301315
sys.exit(1)
302316

303-
# Fetch schema information
304-
print("Fetching database schema information...")
305-
schema_info = await fetch_schema_info(session, conn_id)
306-
schema_text = format_schema_for_prompt(schema_info)
317+
# Fetch database hierarchy
318+
print("Fetching database hierarchy information...")
319+
db_hierarchy = await fetch_database_hierarchy(session, conn_id)
320+
321+
# Display the database hierarchy
322+
print("\nDatabase Structure:")
323+
print("==================\n")
324+
hierarchy_text = format_database_hierarchy(db_hierarchy)
325+
print(hierarchy_text)
326+
print("\n==================\n")
327+
328+
# Use this hierarchy for Claude's prompt
329+
schema_text = hierarchy_text
307330

308-
print(f"Retrieved information for {len(schema_info)} tables.")
331+
print(f"Retrieved database structure information.")
309332

310333
# Generate SQL using Claude with schema context
311334
print("Generating SQL query with Claude...")
@@ -321,6 +344,7 @@ async def main():
321344
print(f"------------")
322345
print(explanation)
323346

347+
# Original query (as generated by Claude)
324348
print(f"\nGenerated SQL query:")
325349
print(f"------------------")
326350
print(sql_query)
@@ -330,8 +354,16 @@ async def main():
330354
print("No SQL query was generated. Exiting.")
331355
sys.exit(1)
332356

357+
# Clean the SQL query before execution
358+
sql_query = clean_sql_query(sql_query)
359+
360+
# Show the cleaned query
361+
print(f"Cleaned SQL query:")
362+
print(f"------------------")
363+
print(sql_query)
364+
print(f"------------------\n")
365+
333366
# Execute the generated SQL query
334-
sql_query = codecs.decode(sql_query, 'unicode_escape')
335367
print("Executing SQL query...")
336368
try:
337369
result = await session.call_tool(

0 commit comments

Comments
 (0)