Skip to content

Commit 5deadcc

Browse files
authored
feat(duckdb): add support for community extension flags (#241)
Introduce runtime flags for community and unsigned extensions in DuckDBConfig, allowing users to enable specific features through SQLSpec. Ensure these flags are applied immediately after connection establishment, enhancing compatibility with older DuckDB builds.
1 parent 646a904 commit 5deadcc

File tree

4 files changed

+113
-5
lines changed

4 files changed

+113
-5
lines changed

docs/reference/adapters.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,15 @@ duckdb
570570
"SELECT * FROM 'https://example.com/data.parquet' LIMIT 10"
571571
)
572572
573+
**Community Extensions**:
574+
575+
DuckDBConfig accepts the runtime flags DuckDB expects for community/unsigned extensions via
576+
``pool_config`` (for example ``allow_community_extensions=True``,
577+
``allow_unsigned_extensions=True``, ``enable_external_access=True``). SQLSpec applies those
578+
options with ``SET`` statements immediately after establishing each connection, so even older
579+
DuckDB builds that do not recognize the options during ``duckdb.connect()`` will still enable the
580+
required permissions before extensions are installed.
581+
573582
**API Reference**:
574583

575584
.. autoclass:: DuckDBConfig

sqlspec/adapters/duckdb/config.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from collections.abc import Callable, Generator
2323

2424
from sqlspec.core import StatementConfig
25-
2625
__all__ = (
2726
"DuckDBConfig",
2827
"DuckDBConnectionParams",
@@ -31,10 +30,21 @@
3130
"DuckDBPoolParams",
3231
"DuckDBSecretConfig",
3332
)
33+
EXTENSION_FLAG_KEYS: "tuple[str, ...]" = (
34+
"allow_community_extensions",
35+
"allow_unsigned_extensions",
36+
"enable_external_access",
37+
)
3438

3539

3640
class DuckDBConnectionParams(TypedDict):
37-
"""DuckDB connection parameters."""
41+
"""DuckDB connection parameters.
42+
43+
Mirrors the keyword arguments accepted by duckdb.connect so callers can drive every DuckDB
44+
configuration switch directly through SQLSpec. All keys are optional and forwarded verbatim
45+
to DuckDB, either as top-level parameters or via the nested ``config`` dictionary when DuckDB
46+
expects them there.
47+
"""
3848

3949
database: NotRequired[str]
4050
read_only: NotRequired[bool]
@@ -75,7 +85,8 @@ class DuckDBConnectionParams(TypedDict):
7585
class DuckDBPoolParams(DuckDBConnectionParams):
7686
"""Complete pool configuration for DuckDB adapter.
7787
78-
Combines standardized pool parameters with DuckDB-specific connection parameters.
88+
Extends DuckDBConnectionParams with pool sizing and lifecycle settings so SQLSpec can manage
89+
per-thread DuckDB connections safely while honoring DuckDB's thread-safety constraints.
7990
"""
8091

8192
pool_min_size: NotRequired[int]
@@ -128,13 +139,16 @@ class DuckDBDriverFeatures(TypedDict):
128139
enable_uuid_conversion: Enable automatic UUID string conversion.
129140
When True (default), UUID strings are automatically converted to UUID objects.
130141
When False, UUID strings are treated as regular strings.
142+
extension_flags: Connection-level flags (e.g., allow_community_extensions) applied
143+
via SET statements immediately after connection creation.
131144
"""
132145

133146
extensions: NotRequired[Sequence[DuckDBExtensionConfig]]
134147
secrets: NotRequired[Sequence[DuckDBSecretConfig]]
135148
on_connection_create: NotRequired["Callable[[DuckDBConnection], DuckDBConnection | None]"]
136149
json_serializer: NotRequired["Callable[[Any], str]"]
137150
enable_uuid_conversion: NotRequired[bool]
151+
extension_flags: NotRequired[dict[str, Any]]
138152

139153

140154
class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, DuckDBDriver]):
@@ -222,13 +236,23 @@ def __init__(
222236
if pool_config.get("database") in {":memory:", ""}:
223237
pool_config["database"] = ":memory:shared_db"
224238

225-
processed_features = dict(driver_features) if driver_features else {}
239+
extension_flags: dict[str, Any] = {}
240+
for key in tuple(pool_config.keys()):
241+
if key in EXTENSION_FLAG_KEYS:
242+
extension_flags[key] = pool_config.pop(key) # type: ignore[misc]
243+
244+
processed_features: dict[str, Any] = dict(driver_features) if driver_features else {}
226245
user_connection_hook = cast(
227246
"Callable[[Any], None] | None", processed_features.pop("on_connection_create", None)
228247
)
229248
processed_features.setdefault("enable_uuid_conversion", True)
230249
serializer = processed_features.setdefault("json_serializer", to_json)
231250

251+
if extension_flags:
252+
existing_flags = cast("dict[str, Any]", processed_features.get("extension_flags", {}))
253+
merged_flags = {**existing_flags, **extension_flags}
254+
processed_features["extension_flags"] = merged_flags
255+
232256
local_observability = observability_config
233257
if user_connection_hook is not None:
234258

@@ -271,11 +295,17 @@ def _create_pool(self) -> DuckDBConnectionPool:
271295

272296
extensions = self.driver_features.get("extensions", None)
273297
secrets = self.driver_features.get("secrets", None)
298+
extension_flags = self.driver_features.get("extension_flags", None)
274299
extensions_dicts = [dict(ext) for ext in extensions] if extensions else None
275300
secrets_dicts = [dict(secret) for secret in secrets] if secrets else None
301+
extension_flags_dict = dict(extension_flags) if extension_flags else None
276302

277303
return DuckDBConnectionPool(
278-
connection_config=connection_config, extensions=extensions_dicts, secrets=secrets_dicts, **self.pool_config
304+
connection_config=connection_config,
305+
extensions=extensions_dicts,
306+
extension_flags=extension_flags_dict,
307+
secrets=secrets_dicts,
308+
**self.pool_config,
279309
)
280310

281311
def _close_pool(self) -> None:

sqlspec/adapters/duckdb/pool.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class DuckDBConnectionPool:
3939
"_connection_config",
4040
"_connection_times",
4141
"_created_connections",
42+
"_extension_flags",
4243
"_extensions",
4344
"_lock",
4445
"_on_connection_create",
@@ -52,6 +53,7 @@ def __init__(
5253
connection_config: "dict[str, Any]",
5354
pool_recycle_seconds: int = POOL_RECYCLE,
5455
extensions: "list[dict[str, Any]] | None" = None,
56+
extension_flags: "dict[str, Any] | None" = None,
5557
secrets: "list[dict[str, Any]] | None" = None,
5658
on_connection_create: "Callable[[DuckDBConnection], None] | None" = None,
5759
**kwargs: Any,
@@ -62,13 +64,15 @@ def __init__(
6264
connection_config: DuckDB connection configuration
6365
pool_recycle_seconds: Connection recycle time in seconds
6466
extensions: List of extensions to install/load
67+
extension_flags: Connection-level SET statements applied after creation
6568
secrets: List of secrets to create
6669
on_connection_create: Callback executed when connection is created
6770
**kwargs: Additional parameters ignored for compatibility
6871
"""
6972
self._connection_config = connection_config
7073
self._recycle = pool_recycle_seconds
7174
self._extensions = extensions or []
75+
self._extension_flags = extension_flags or {}
7276
self._secrets = secrets or []
7377
self._on_connection_create = on_connection_create
7478
self._thread_local = threading.local()
@@ -92,6 +96,8 @@ def _create_connection(self) -> DuckDBConnection:
9296

9397
connection = duckdb.connect(**connect_parameters)
9498

99+
self._apply_extension_flags(connection)
100+
95101
for ext_config in self._extensions:
96102
ext_name = ext_config.get("name")
97103
if not ext_name:
@@ -149,6 +155,33 @@ def _create_connection(self) -> DuckDBConnection:
149155

150156
return connection
151157

158+
def _apply_extension_flags(self, connection: DuckDBConnection) -> None:
159+
"""Apply connection-level extension flags via SET statements."""
160+
161+
if not self._extension_flags:
162+
return
163+
164+
for key, value in self._extension_flags.items():
165+
if not key or not key.replace("_", "").isalnum():
166+
continue
167+
168+
normalized = self._normalize_flag_value(value)
169+
try:
170+
connection.execute(f"SET {key} = {normalized}")
171+
except Exception as exc: # pragma: no cover - best-effort guard
172+
logger.debug("Failed to set DuckDB flag %s: %s", key, exc)
173+
174+
@staticmethod
175+
def _normalize_flag_value(value: Any) -> str:
176+
"""Convert Python value to DuckDB SET literal."""
177+
178+
if isinstance(value, bool):
179+
return "TRUE" if value else "FALSE"
180+
if isinstance(value, (int, float)):
181+
return str(value)
182+
escaped = str(value).replace("'", "''")
183+
return f"'{escaped}'"
184+
152185
def _get_thread_connection(self) -> DuckDBConnection:
153186
"""Get or create a connection for the current thread.
154187
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
3+
pytest.importorskip("duckdb", reason="DuckDB adapter requires duckdb package")
4+
5+
from sqlspec.adapters.duckdb import DuckDBConfig
6+
7+
8+
def test_duckdb_config_promotes_security_flags() -> None:
9+
config = DuckDBConfig(
10+
pool_config={
11+
"database": ":memory:",
12+
"allow_community_extensions": True,
13+
"allow_unsigned_extensions": False,
14+
"enable_external_access": True,
15+
}
16+
)
17+
18+
flags = config.driver_features.get("extension_flags")
19+
assert flags == {
20+
"allow_community_extensions": True,
21+
"allow_unsigned_extensions": False,
22+
"enable_external_access": True,
23+
}
24+
assert "allow_community_extensions" not in config.pool_config
25+
assert "allow_unsigned_extensions" not in config.pool_config
26+
assert "enable_external_access" not in config.pool_config
27+
28+
29+
def test_duckdb_config_merges_existing_extension_flags() -> None:
30+
config = DuckDBConfig(
31+
pool_config={"database": ":memory:", "allow_community_extensions": True},
32+
driver_features={"extension_flags": {"custom": "value"}},
33+
)
34+
35+
flags = config.driver_features.get("extension_flags")
36+
assert flags == {"custom": "value", "allow_community_extensions": True}

0 commit comments

Comments
 (0)