Skip to content

Commit cd88b4a

Browse files
authored
fix: update signature namespace and set configs consistently (#236)
Update config consistency and update signature namespaces
1 parent f1729a4 commit cd88b4a

File tree

14 files changed

+338
-269
lines changed

14 files changed

+338
-269
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ alloydb = ["google-cloud-alloydb-connector"]
2424
asyncmy = ["asyncmy"]
2525
asyncpg = ["asyncpg"]
2626
attrs = ["attrs", "cattrs"]
27-
bigquery = ["google-cloud-bigquery"]
27+
bigquery = ["google-cloud-bigquery", "google-cloud-storage"]
2828
cli = ["rich-click"]
2929
cloud-sql = ["cloud-sql-python-connector"]
3030
duckdb = ["duckdb"]

sqlspec/adapters/adbc/config.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from typing_extensions import NotRequired
99

1010
from sqlspec.adapters.adbc._types import AdbcConnection
11-
from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, get_adbc_statement_config
11+
from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, AdbcExceptionHandler, get_adbc_statement_config
1212
from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, NoPoolSyncConfig, StarletteConfig
1313
from sqlspec.core import StatementConfig
1414
from sqlspec.exceptions import ImproperConfigurationError
1515
from sqlspec.utils.module_loader import import_string
16+
from sqlspec.utils.serializers import to_json
1617

1718
if TYPE_CHECKING:
1819
from collections.abc import Generator
@@ -140,20 +141,12 @@ def __init__(
140141
detected_dialect = str(self._get_dialect() or "sqlite")
141142
statement_config = get_adbc_statement_config(detected_dialect)
142143

143-
from sqlspec.utils.serializers import to_json
144+
processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
145+
json_serializer = processed_driver_features.setdefault("json_serializer", to_json)
146+
processed_driver_features.setdefault("enable_cast_detection", True)
147+
processed_driver_features.setdefault("strict_type_coercion", False)
148+
processed_driver_features.setdefault("arrow_extension_types", True)
144149

145-
if driver_features is None:
146-
driver_features = {}
147-
if "json_serializer" not in driver_features:
148-
driver_features["json_serializer"] = to_json
149-
if "enable_cast_detection" not in driver_features:
150-
driver_features["enable_cast_detection"] = True
151-
if "strict_type_coercion" not in driver_features:
152-
driver_features["strict_type_coercion"] = False
153-
if "arrow_extension_types" not in driver_features:
154-
driver_features["arrow_extension_types"] = True
155-
156-
json_serializer = driver_features.get("json_serializer")
157150
if json_serializer is not None:
158151
parameter_config = statement_config.parameter_config
159152
previous_list_converter = parameter_config.type_coercion_map.get(list)
@@ -172,7 +165,7 @@ def __init__(
172165
connection_config=self.connection_config,
173166
migration_config=migration_config,
174167
statement_config=statement_config,
175-
driver_features=dict(driver_features),
168+
driver_features=processed_driver_features,
176169
bind_key=bind_key,
177170
extension_config=extension_config,
178171
)
@@ -420,13 +413,18 @@ def _get_connection_config_dict(self) -> dict[str, Any]:
420413

421414
return config
422415

423-
def get_signature_namespace(self) -> "dict[str, type[Any]]":
416+
def get_signature_namespace(self) -> "dict[str, Any]":
424417
"""Get the signature namespace for types.
425418
426419
Returns:
427420
Dictionary mapping type names to types.
428421
"""
429-
430422
namespace = super().get_signature_namespace()
431-
namespace.update({"AdbcConnection": AdbcConnection, "AdbcCursor": AdbcCursor})
423+
namespace.update({
424+
"AdbcConnection": AdbcConnection,
425+
"AdbcConnectionParams": AdbcConnectionParams,
426+
"AdbcCursor": AdbcCursor,
427+
"AdbcDriver": AdbcDriver,
428+
"AdbcExceptionHandler": AdbcExceptionHandler,
429+
})
432430
return namespace

sqlspec/adapters/aiosqlite/config.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,19 @@
77
from typing_extensions import NotRequired
88

99
from sqlspec.adapters.aiosqlite._types import AiosqliteConnection
10-
from sqlspec.adapters.aiosqlite.driver import AiosqliteCursor, AiosqliteDriver, aiosqlite_statement_config
10+
from sqlspec.adapters.aiosqlite.driver import (
11+
AiosqliteCursor,
12+
AiosqliteDriver,
13+
AiosqliteExceptionHandler,
14+
aiosqlite_statement_config,
15+
)
1116
from sqlspec.adapters.aiosqlite.pool import (
1217
AiosqliteConnectionPool,
1318
AiosqliteConnectTimeoutError,
1419
AiosqlitePoolClosedError,
1520
AiosqlitePoolConnection,
1621
)
22+
from sqlspec.adapters.sqlite._type_handlers import register_type_handlers
1723
from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig
1824
from sqlspec.utils.serializers import from_json, to_json
1925

@@ -117,20 +123,11 @@ def __init__(
117123
config_dict["uri"] = True
118124

119125
processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
120-
121-
if "enable_custom_adapters" not in processed_driver_features:
122-
processed_driver_features["enable_custom_adapters"] = True
123-
124-
if "json_serializer" not in processed_driver_features:
125-
processed_driver_features["json_serializer"] = to_json
126-
127-
if "json_deserializer" not in processed_driver_features:
128-
processed_driver_features["json_deserializer"] = from_json
126+
processed_driver_features.setdefault("enable_custom_adapters", True)
127+
json_serializer = processed_driver_features.setdefault("json_serializer", to_json)
128+
json_deserializer = processed_driver_features.setdefault("json_deserializer", from_json)
129129

130130
base_statement_config = statement_config or aiosqlite_statement_config
131-
132-
json_serializer = processed_driver_features.get("json_serializer")
133-
json_deserializer = processed_driver_features.get("json_deserializer")
134131
if json_serializer is not None:
135132
parameter_config = base_statement_config.parameter_config.with_json_serializers(
136133
json_serializer, deserializer=json_deserializer
@@ -250,8 +247,6 @@ def _register_type_adapters(self) -> None:
250247
sync adapter, so this shares the implementation.
251248
"""
252249
if self.driver_features.get("enable_custom_adapters", False):
253-
from sqlspec.adapters.sqlite._type_handlers import register_type_handlers
254-
255250
register_type_handlers(
256251
json_serializer=self.driver_features.get("json_serializer"),
257252
json_deserializer=self.driver_features.get("json_deserializer"),
@@ -283,7 +278,7 @@ async def provide_pool(self) -> AiosqliteConnectionPool:
283278
self.pool_instance = await self.create_pool()
284279
return self.pool_instance
285280

286-
def get_signature_namespace(self) -> "dict[str, type[Any]]":
281+
def get_signature_namespace(self) -> "dict[str, Any]":
287282
"""Get the signature namespace for aiosqlite types.
288283
289284
Returns:
@@ -292,11 +287,16 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
292287
namespace = super().get_signature_namespace()
293288
namespace.update({
294289
"AiosqliteConnection": AiosqliteConnection,
290+
"AiosqliteConnectionParams": AiosqliteConnectionParams,
295291
"AiosqliteConnectionPool": AiosqliteConnectionPool,
296292
"AiosqliteConnectTimeoutError": AiosqliteConnectTimeoutError,
297293
"AiosqliteCursor": AiosqliteCursor,
294+
"AiosqliteDriver": AiosqliteDriver,
295+
"AiosqliteDriverFeatures": AiosqliteDriverFeatures,
296+
"AiosqliteExceptionHandler": AiosqliteExceptionHandler,
298297
"AiosqlitePoolClosedError": AiosqlitePoolClosedError,
299298
"AiosqlitePoolConnection": AiosqlitePoolConnection,
299+
"AiosqlitePoolParams": AiosqlitePoolParams,
300300
})
301301
return namespace
302302

sqlspec/adapters/asyncmy/config.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sqlspec.adapters.asyncmy.driver import (
1515
AsyncmyCursor,
1616
AsyncmyDriver,
17+
AsyncmyExceptionHandler,
1718
asyncmy_statement_config,
1819
build_asyncmy_statement_config,
1920
)
@@ -121,10 +122,8 @@ def __init__(
121122
extras = processed_pool_config.pop("extra")
122123
processed_pool_config.update(extras)
123124

124-
if "host" not in processed_pool_config:
125-
processed_pool_config["host"] = "localhost"
126-
if "port" not in processed_pool_config:
127-
processed_pool_config["port"] = 3306
125+
processed_pool_config.setdefault("host", "localhost")
126+
processed_pool_config.setdefault("port", 3306)
128127

129128
processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
130129
serializer = processed_driver_features.setdefault("json_serializer", to_json)
@@ -221,7 +220,7 @@ async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: i
221220
self.pool_instance = await self.create_pool()
222221
return self.pool_instance
223222

224-
def get_signature_namespace(self) -> "dict[str, type[Any]]":
223+
def get_signature_namespace(self) -> "dict[str, Any]":
225224
"""Get the signature namespace for Asyncmy types.
226225
227226
Returns:
@@ -231,7 +230,12 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
231230
namespace = super().get_signature_namespace()
232231
namespace.update({
233232
"AsyncmyConnection": AsyncmyConnection,
234-
"AsyncmyPool": AsyncmyPool,
233+
"AsyncmyConnectionParams": AsyncmyConnectionParams,
235234
"AsyncmyCursor": AsyncmyCursor,
235+
"AsyncmyDriver": AsyncmyDriver,
236+
"AsyncmyDriverFeatures": AsyncmyDriverFeatures,
237+
"AsyncmyExceptionHandler": AsyncmyExceptionHandler,
238+
"AsyncmyPool": AsyncmyPool,
239+
"AsyncmyPoolParams": AsyncmyPoolParams,
236240
})
237241
return namespace

sqlspec/adapters/asyncpg/config.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
1212
from typing_extensions import NotRequired
1313

14-
from sqlspec.adapters.asyncpg._types import AsyncpgConnection
14+
from sqlspec.adapters.asyncpg._type_handlers import register_json_codecs, register_pgvector_support
15+
from sqlspec.adapters.asyncpg._types import AsyncpgConnection, AsyncpgPool
1516
from sqlspec.adapters.asyncpg.driver import (
1617
AsyncpgCursor,
1718
AsyncpgDriver,
19+
AsyncpgExceptionHandler,
1820
asyncpg_statement_config,
1921
build_asyncpg_statement_config,
2022
)
@@ -329,8 +331,7 @@ async def _create_pool(self) -> "Pool[Record]":
329331
elif self.driver_features.get("enable_alloydb", False):
330332
self._setup_alloydb_connector(config)
331333

332-
if "init" not in config:
333-
config["init"] = self._init_connection
334+
config.setdefault("init", self._init_connection)
334335

335336
return await asyncpg_create_pool(**config)
336337

@@ -341,17 +342,13 @@ async def _init_connection(self, connection: "AsyncpgConnection") -> None:
341342
connection: AsyncPG connection to initialize.
342343
"""
343344
if self.driver_features.get("enable_json_codecs", True):
344-
from sqlspec.adapters.asyncpg._type_handlers import register_json_codecs
345-
346345
await register_json_codecs(
347346
connection,
348347
encoder=self.driver_features.get("json_serializer", to_json),
349348
decoder=self.driver_features.get("json_deserializer", from_json),
350349
)
351350

352351
if self.driver_features.get("enable_pgvector", False):
353-
from sqlspec.adapters.asyncpg._type_handlers import register_pgvector_support
354-
355352
await register_pgvector_support(connection)
356353

357354
async def _close_pool(self) -> None:
@@ -432,7 +429,7 @@ async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
432429
self.pool_instance = await self.create_pool()
433430
return self.pool_instance
434431

435-
def get_signature_namespace(self) -> "dict[str, type[Any]]":
432+
def get_signature_namespace(self) -> "dict[str, Any]":
436433
"""Get the signature namespace for AsyncPG types.
437434
438435
This provides all AsyncPG-specific types that Litestar needs to recognize
@@ -450,7 +447,12 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
450447
"PoolConnectionProxyMeta": PoolConnectionProxyMeta,
451448
"ConnectionMeta": ConnectionMeta,
452449
"Record": Record,
453-
"AsyncpgConnection": AsyncpgConnection, # type: ignore[dict-item]
450+
"AsyncpgConnection": AsyncpgConnection,
451+
"AsyncpgConnectionConfig": AsyncpgConnectionConfig,
454452
"AsyncpgCursor": AsyncpgCursor,
453+
"AsyncpgDriver": AsyncpgDriver,
454+
"AsyncpgExceptionHandler": AsyncpgExceptionHandler,
455+
"AsyncpgPool": AsyncpgPool,
456+
"AsyncpgPoolConfig": AsyncpgPoolConfig,
455457
})
456458
return namespace

sqlspec/adapters/bigquery/config.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
from typing_extensions import NotRequired
99

1010
from sqlspec.adapters.bigquery._types import BigQueryConnection
11-
from sqlspec.adapters.bigquery.driver import BigQueryCursor, BigQueryDriver, build_bigquery_statement_config
11+
from sqlspec.adapters.bigquery.driver import (
12+
BigQueryCursor,
13+
BigQueryDriver,
14+
BigQueryExceptionHandler,
15+
build_bigquery_statement_config,
16+
)
1217
from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, NoPoolSyncConfig, StarletteConfig
1318
from sqlspec.exceptions import ImproperConfigurationError
1419
from sqlspec.typing import Empty
@@ -134,10 +139,7 @@ def __init__(
134139
self.connection_config.update(extras)
135140

136141
self.driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
137-
138-
if "enable_uuid_conversion" not in self.driver_features:
139-
self.driver_features["enable_uuid_conversion"] = True
140-
142+
self.driver_features.setdefault("enable_uuid_conversion", True)
141143
serializer = self.driver_features.setdefault("json_serializer", to_json)
142144

143145
self._connection_instance: BigQueryConnection | None = self.driver_features.get("connection_instance")
@@ -263,13 +265,19 @@ def provide_session(
263265
)
264266
yield driver
265267

266-
def get_signature_namespace(self) -> "dict[str, type[Any]]":
268+
def get_signature_namespace(self) -> "dict[str, Any]":
267269
"""Get the signature namespace for BigQuery types.
268270
269271
Returns:
270272
Dictionary mapping type names to types.
271273
"""
272274

273275
namespace = super().get_signature_namespace()
274-
namespace.update({"BigQueryConnection": BigQueryConnection, "BigQueryCursor": BigQueryCursor})
276+
namespace.update({
277+
"BigQueryConnection": BigQueryConnection,
278+
"BigQueryConnectionParams": BigQueryConnectionParams,
279+
"BigQueryCursor": BigQueryCursor,
280+
"BigQueryDriver": BigQueryDriver,
281+
"BigQueryExceptionHandler": BigQueryExceptionHandler,
282+
})
275283
return namespace

sqlspec/adapters/duckdb/config.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from typing_extensions import NotRequired
88

99
from sqlspec.adapters.duckdb._types import DuckDBConnection
10-
from sqlspec.adapters.duckdb.driver import DuckDBCursor, DuckDBDriver, build_duckdb_statement_config
10+
from sqlspec.adapters.duckdb.driver import (
11+
DuckDBCursor,
12+
DuckDBDriver,
13+
DuckDBExceptionHandler,
14+
build_duckdb_statement_config,
15+
)
1116
from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool
1217
from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig, SyncDatabaseConfig
1318
from sqlspec.utils.serializers import to_json
@@ -209,8 +214,7 @@ def __init__(
209214
"""
210215
if pool_config is None:
211216
pool_config = {}
212-
if "database" not in pool_config:
213-
pool_config["database"] = ":memory:shared_db"
217+
pool_config.setdefault("database", ":memory:shared_db")
214218

215219
if pool_config.get("database") in {":memory:", ""}:
216220
pool_config["database"] = ":memory:shared_db"
@@ -331,7 +335,7 @@ def provide_session(
331335
)
332336
yield driver
333337

334-
def get_signature_namespace(self) -> "dict[str, type[Any]]":
338+
def get_signature_namespace(self) -> "dict[str, Any]":
335339
"""Get the signature namespace for DuckDB types.
336340
337341
This provides all DuckDB-specific types that Litestar needs to recognize
@@ -342,5 +346,16 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
342346
"""
343347

344348
namespace = super().get_signature_namespace()
345-
namespace.update({"DuckDBConnection": DuckDBConnection, "DuckDBCursor": DuckDBCursor})
349+
namespace.update({
350+
"DuckDBConnection": DuckDBConnection,
351+
"DuckDBConnectionParams": DuckDBConnectionParams,
352+
"DuckDBConnectionPool": DuckDBConnectionPool,
353+
"DuckDBCursor": DuckDBCursor,
354+
"DuckDBDriver": DuckDBDriver,
355+
"DuckDBDriverFeatures": DuckDBDriverFeatures,
356+
"DuckDBExceptionHandler": DuckDBExceptionHandler,
357+
"DuckDBExtensionConfig": DuckDBExtensionConfig,
358+
"DuckDBPoolParams": DuckDBPoolParams,
359+
"DuckDBSecretConfig": DuckDBSecretConfig,
360+
})
346361
return namespace

0 commit comments

Comments
 (0)