2121
2222import logging
2323from datetime import datetime , date
24+ from types import ModuleType
2425
2526from sqlalchemy import types as sqltypes
2627from sqlalchemy .engine import default , reflection
@@ -207,6 +208,12 @@ def initialize(self, connection):
207208 self .default_schema_name = \
208209 self ._get_default_schema_name (connection )
209210
211+ def set_isolation_level (self , dbapi_connection , level ):
212+ """
213+ For CrateDB, this is implemented as a noop.
214+ """
215+ pass
216+
210217 def do_rollback (self , connection ):
211218 # if any exception is raised by the dbapi, sqlalchemy by default
212219 # attempts to do a rollback crate doesn't support rollbacks.
@@ -225,7 +232,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
225232 use_ssl = asbool (kwargs .pop ("ssl" , False ))
226233 if use_ssl :
227234 servers = ["https://" + server for server in servers ]
228- return self .dbapi .connect (servers = servers , ** kwargs )
235+
236+ is_module = isinstance (self .dbapi , ModuleType )
237+ if is_module :
238+ driver_name = self .dbapi .__name__
239+ else :
240+ driver_name = self .dbapi .__class__ .__name__
241+ if driver_name == "crate.client" :
242+ if "database" in kwargs :
243+ del kwargs ["database" ]
244+ return self .dbapi .connect (servers = servers , ** kwargs )
245+ elif driver_name in ["psycopg" , "PsycopgAdaptDBAPI" , "AsyncAdapt_asyncpg_dbapi" ]:
246+ return self .dbapi .connect (host = host , port = port , ** kwargs )
247+ else :
248+ raise ValueError (f"Unknown driver variant: { driver_name } " )
249+
229250 return self .dbapi .connect (** kwargs )
230251
231252 def _get_default_schema_name (self , connection ):
@@ -271,11 +292,11 @@ def get_schema_names(self, connection, **kw):
271292 def get_table_names (self , connection , schema = None , ** kw ):
272293 if schema is None :
273294 schema = self ._get_effective_schema_name (connection )
274- cursor = connection .exec_driver_sql (
295+ cursor = connection .exec_driver_sql (self . _format_query (
275296 "SELECT table_name FROM information_schema.tables "
276297 "WHERE {0} = ? "
277298 "AND table_type = 'BASE TABLE' "
278- "ORDER BY table_name ASC, {0} ASC" .format (self .schema_column ),
299+ "ORDER BY table_name ASC, {0} ASC" ) .format (self .schema_column ),
279300 (schema or self .default_schema_name , )
280301 )
281302 return [row [0 ] for row in cursor .fetchall ()]
@@ -297,7 +318,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
297318 "AND column_name !~ ?" \
298319 .format (self .schema_column )
299320 cursor = connection .exec_driver_sql (
300- query ,
321+ self . _format_query ( query ) ,
301322 (table_name ,
302323 schema or self .default_schema_name ,
303324 r"(.*)\[\'(.*)\'\]" ) # regex to filter subscript
@@ -336,7 +357,7 @@ def result_fun(result):
336357 return set (rows [0 ] if rows else [])
337358
338359 pk_result = engine .exec_driver_sql (
339- query ,
360+ self . _format_query ( query ) ,
340361 (table_name , schema or self .default_schema_name )
341362 )
342363 pks = result_fun (pk_result )
@@ -377,6 +398,17 @@ def has_ilike_operator(self):
377398 server_version_info = self .server_version_info
378399 return server_version_info is not None and server_version_info >= (4 , 1 , 0 )
379400
401+ def _format_query (self , query ):
402+ """
403+ When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
404+ the paramstyle is not `qmark`, but `pyformat`.
405+
406+ TODO: Review: Is it legit and sane? Are there alternatives?
407+ """
408+ if self .paramstyle == "pyformat" :
409+ query = query .replace ("= ?" , "= %s" ).replace ("!~ ?" , "!~ %s" )
410+ return query
411+
380412
381413class DateTrunc (functions .GenericFunction ):
382414 name = "date_trunc"
0 commit comments