|
22 | 22 | from collections.abc import Callable, Generator |
23 | 23 |
|
24 | 24 | from sqlspec.core import StatementConfig |
25 | | - |
26 | 25 | __all__ = ( |
27 | 26 | "DuckDBConfig", |
28 | 27 | "DuckDBConnectionParams", |
|
31 | 30 | "DuckDBPoolParams", |
32 | 31 | "DuckDBSecretConfig", |
33 | 32 | ) |
| 33 | +EXTENSION_FLAG_KEYS: "tuple[str, ...]" = ( |
| 34 | + "allow_community_extensions", |
| 35 | + "allow_unsigned_extensions", |
| 36 | + "enable_external_access", |
| 37 | +) |
34 | 38 |
|
35 | 39 |
|
36 | 40 | class DuckDBConnectionParams(TypedDict): |
37 | | - """DuckDB connection parameters.""" |
| 41 | + """DuckDB connection parameters. |
| 42 | +
|
| 43 | + Mirrors the keyword arguments accepted by duckdb.connect so callers can drive every DuckDB |
| 44 | + configuration switch directly through SQLSpec. All keys are optional and forwarded verbatim |
| 45 | + to DuckDB, either as top-level parameters or via the nested ``config`` dictionary when DuckDB |
| 46 | + expects them there. |
| 47 | + """ |
38 | 48 |
|
39 | 49 | database: NotRequired[str] |
40 | 50 | read_only: NotRequired[bool] |
@@ -75,7 +85,8 @@ class DuckDBConnectionParams(TypedDict): |
75 | 85 | class DuckDBPoolParams(DuckDBConnectionParams): |
76 | 86 | """Complete pool configuration for DuckDB adapter. |
77 | 87 |
|
78 | | - Combines standardized pool parameters with DuckDB-specific connection parameters. |
| 88 | + Extends DuckDBConnectionParams with pool sizing and lifecycle settings so SQLSpec can manage |
| 89 | + per-thread DuckDB connections safely while honoring DuckDB's thread-safety constraints. |
79 | 90 | """ |
80 | 91 |
|
81 | 92 | pool_min_size: NotRequired[int] |
@@ -128,13 +139,16 @@ class DuckDBDriverFeatures(TypedDict): |
128 | 139 | enable_uuid_conversion: Enable automatic UUID string conversion. |
129 | 140 | When True (default), UUID strings are automatically converted to UUID objects. |
130 | 141 | When False, UUID strings are treated as regular strings. |
| 142 | + extension_flags: Connection-level flags (e.g., allow_community_extensions) applied |
| 143 | + via SET statements immediately after connection creation. |
131 | 144 | """ |
132 | 145 |
|
133 | 146 | extensions: NotRequired[Sequence[DuckDBExtensionConfig]] |
134 | 147 | secrets: NotRequired[Sequence[DuckDBSecretConfig]] |
135 | 148 | on_connection_create: NotRequired["Callable[[DuckDBConnection], DuckDBConnection | None]"] |
136 | 149 | json_serializer: NotRequired["Callable[[Any], str]"] |
137 | 150 | enable_uuid_conversion: NotRequired[bool] |
| 151 | + extension_flags: NotRequired[dict[str, Any]] |
138 | 152 |
|
139 | 153 |
|
140 | 154 | class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, DuckDBDriver]): |
@@ -222,13 +236,23 @@ def __init__( |
222 | 236 | if pool_config.get("database") in {":memory:", ""}: |
223 | 237 | pool_config["database"] = ":memory:shared_db" |
224 | 238 |
|
225 | | - processed_features = dict(driver_features) if driver_features else {} |
| 239 | + extension_flags: dict[str, Any] = {} |
| 240 | + for key in tuple(pool_config.keys()): |
| 241 | + if key in EXTENSION_FLAG_KEYS: |
| 242 | + extension_flags[key] = pool_config.pop(key) # type: ignore[misc] |
| 243 | + |
| 244 | + processed_features: dict[str, Any] = dict(driver_features) if driver_features else {} |
226 | 245 | user_connection_hook = cast( |
227 | 246 | "Callable[[Any], None] | None", processed_features.pop("on_connection_create", None) |
228 | 247 | ) |
229 | 248 | processed_features.setdefault("enable_uuid_conversion", True) |
230 | 249 | serializer = processed_features.setdefault("json_serializer", to_json) |
231 | 250 |
|
| 251 | + if extension_flags: |
| 252 | + existing_flags = cast("dict[str, Any]", processed_features.get("extension_flags", {})) |
| 253 | + merged_flags = {**existing_flags, **extension_flags} |
| 254 | + processed_features["extension_flags"] = merged_flags |
| 255 | + |
232 | 256 | local_observability = observability_config |
233 | 257 | if user_connection_hook is not None: |
234 | 258 |
|
@@ -271,11 +295,17 @@ def _create_pool(self) -> DuckDBConnectionPool: |
271 | 295 |
|
272 | 296 | extensions = self.driver_features.get("extensions", None) |
273 | 297 | secrets = self.driver_features.get("secrets", None) |
| 298 | + extension_flags = self.driver_features.get("extension_flags", None) |
274 | 299 | extensions_dicts = [dict(ext) for ext in extensions] if extensions else None |
275 | 300 | secrets_dicts = [dict(secret) for secret in secrets] if secrets else None |
| 301 | + extension_flags_dict = dict(extension_flags) if extension_flags else None |
276 | 302 |
|
277 | 303 | return DuckDBConnectionPool( |
278 | | - connection_config=connection_config, extensions=extensions_dicts, secrets=secrets_dicts, **self.pool_config |
| 304 | + connection_config=connection_config, |
| 305 | + extensions=extensions_dicts, |
| 306 | + extension_flags=extension_flags_dict, |
| 307 | + secrets=secrets_dicts, |
| 308 | + **self.pool_config, |
279 | 309 | ) |
280 | 310 |
|
281 | 311 | def _close_pool(self) -> None: |
|
0 commit comments