Skip to content

Commit 8d7f4e8

Browse files
authored
Merge branch 'main' into 173_usage_configuration
2 parents 93250d4 + 5deadcc commit 8d7f4e8

File tree

4 files changed

+113
-5
lines changed

4 files changed

+113
-5
lines changed

docs/reference/adapters.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,15 @@ duckdb
570570
"SELECT * FROM 'https://example.com/data.parquet' LIMIT 10"
571571
)
572572
573+
**Community Extensions**:
574+
575+
DuckDBConfig accepts the runtime flags DuckDB expects for community/unsigned extensions via
576+
``pool_config`` (for example ``allow_community_extensions=True``,
577+
``allow_unsigned_extensions=True``, ``enable_external_access=True``). SQLSpec applies those
578+
options with ``SET`` statements immediately after establishing each connection, so even older
579+
DuckDB builds that do not recognize the options during ``duckdb.connect()`` will still enable the
580+
required permissions before extensions are installed.
581+
573582
**API Reference**:
574583

575584
.. autoclass:: DuckDBConfig

sqlspec/adapters/duckdb/config.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from collections.abc import Callable, Generator
2323

2424
from sqlspec.core import StatementConfig
25-
2625
__all__ = (
2726
"DuckDBConfig",
2827
"DuckDBConnectionParams",
@@ -31,10 +30,21 @@
3130
"DuckDBPoolParams",
3231
"DuckDBSecretConfig",
3332
)
33+
EXTENSION_FLAG_KEYS: "tuple[str, ...]" = (
34+
"allow_community_extensions",
35+
"allow_unsigned_extensions",
36+
"enable_external_access",
37+
)
3438

3539

3640
class DuckDBConnectionParams(TypedDict):
37-
"""DuckDB connection parameters."""
41+
"""DuckDB connection parameters.
42+
43+
Mirrors the keyword arguments accepted by duckdb.connect so callers can drive every DuckDB
44+
configuration switch directly through SQLSpec. All keys are optional and forwarded verbatim
45+
to DuckDB, either as top-level parameters or via the nested ``config`` dictionary when DuckDB
46+
expects them there.
47+
"""
3848

3949
database: NotRequired[str]
4050
read_only: NotRequired[bool]
@@ -75,7 +85,8 @@ class DuckDBConnectionParams(TypedDict):
7585
class DuckDBPoolParams(DuckDBConnectionParams):
7686
"""Complete pool configuration for DuckDB adapter.
7787
78-
Combines standardized pool parameters with DuckDB-specific connection parameters.
88+
Extends DuckDBConnectionParams with pool sizing and lifecycle settings so SQLSpec can manage
89+
per-thread DuckDB connections safely while honoring DuckDB's thread-safety constraints.
7990
"""
8091

8192
pool_min_size: NotRequired[int]
@@ -128,13 +139,16 @@ class DuckDBDriverFeatures(TypedDict):
128139
enable_uuid_conversion: Enable automatic UUID string conversion.
129140
When True (default), UUID strings are automatically converted to UUID objects.
130141
When False, UUID strings are treated as regular strings.
142+
extension_flags: Connection-level flags (e.g., allow_community_extensions) applied
143+
via SET statements immediately after connection creation.
131144
"""
132145

133146
extensions: NotRequired[Sequence[DuckDBExtensionConfig]]
134147
secrets: NotRequired[Sequence[DuckDBSecretConfig]]
135148
on_connection_create: NotRequired["Callable[[DuckDBConnection], DuckDBConnection | None]"]
136149
json_serializer: NotRequired["Callable[[Any], str]"]
137150
enable_uuid_conversion: NotRequired[bool]
151+
extension_flags: NotRequired[dict[str, Any]]
138152

139153

140154
class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, DuckDBDriver]):
@@ -222,13 +236,23 @@ def __init__(
222236
if pool_config.get("database") in {":memory:", ""}:
223237
pool_config["database"] = ":memory:shared_db"
224238

225-
processed_features = dict(driver_features) if driver_features else {}
239+
extension_flags: dict[str, Any] = {}
240+
for key in tuple(pool_config.keys()):
241+
if key in EXTENSION_FLAG_KEYS:
242+
extension_flags[key] = pool_config.pop(key) # type: ignore[misc]
243+
244+
processed_features: dict[str, Any] = dict(driver_features) if driver_features else {}
226245
user_connection_hook = cast(
227246
"Callable[[Any], None] | None", processed_features.pop("on_connection_create", None)
228247
)
229248
processed_features.setdefault("enable_uuid_conversion", True)
230249
serializer = processed_features.setdefault("json_serializer", to_json)
231250

251+
if extension_flags:
252+
existing_flags = cast("dict[str, Any]", processed_features.get("extension_flags", {}))
253+
merged_flags = {**existing_flags, **extension_flags}
254+
processed_features["extension_flags"] = merged_flags
255+
232256
local_observability = observability_config
233257
if user_connection_hook is not None:
234258

@@ -271,11 +295,17 @@ def _create_pool(self) -> DuckDBConnectionPool:
271295

272296
extensions = self.driver_features.get("extensions", None)
273297
secrets = self.driver_features.get("secrets", None)
298+
extension_flags = self.driver_features.get("extension_flags", None)
274299
extensions_dicts = [dict(ext) for ext in extensions] if extensions else None
275300
secrets_dicts = [dict(secret) for secret in secrets] if secrets else None
301+
extension_flags_dict = dict(extension_flags) if extension_flags else None
276302

277303
return DuckDBConnectionPool(
278-
connection_config=connection_config, extensions=extensions_dicts, secrets=secrets_dicts, **self.pool_config
304+
connection_config=connection_config,
305+
extensions=extensions_dicts,
306+
extension_flags=extension_flags_dict,
307+
secrets=secrets_dicts,
308+
**self.pool_config,
279309
)
280310

281311
def _close_pool(self) -> None:

sqlspec/adapters/duckdb/pool.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class DuckDBConnectionPool:
3939
"_connection_config",
4040
"_connection_times",
4141
"_created_connections",
42+
"_extension_flags",
4243
"_extensions",
4344
"_lock",
4445
"_on_connection_create",
@@ -52,6 +53,7 @@ def __init__(
5253
connection_config: "dict[str, Any]",
5354
pool_recycle_seconds: int = POOL_RECYCLE,
5455
extensions: "list[dict[str, Any]] | None" = None,
56+
extension_flags: "dict[str, Any] | None" = None,
5557
secrets: "list[dict[str, Any]] | None" = None,
5658
on_connection_create: "Callable[[DuckDBConnection], None] | None" = None,
5759
**kwargs: Any,
@@ -62,13 +64,15 @@ def __init__(
6264
connection_config: DuckDB connection configuration
6365
pool_recycle_seconds: Connection recycle time in seconds
6466
extensions: List of extensions to install/load
67+
extension_flags: Connection-level SET statements applied after creation
6568
secrets: List of secrets to create
6669
on_connection_create: Callback executed when connection is created
6770
**kwargs: Additional parameters ignored for compatibility
6871
"""
6972
self._connection_config = connection_config
7073
self._recycle = pool_recycle_seconds
7174
self._extensions = extensions or []
75+
self._extension_flags = extension_flags or {}
7276
self._secrets = secrets or []
7377
self._on_connection_create = on_connection_create
7478
self._thread_local = threading.local()
@@ -92,6 +96,8 @@ def _create_connection(self) -> DuckDBConnection:
9296

9397
connection = duckdb.connect(**connect_parameters)
9498

99+
self._apply_extension_flags(connection)
100+
95101
for ext_config in self._extensions:
96102
ext_name = ext_config.get("name")
97103
if not ext_name:
@@ -149,6 +155,33 @@ def _create_connection(self) -> DuckDBConnection:
149155

150156
return connection
151157

158+
def _apply_extension_flags(self, connection: DuckDBConnection) -> None:
159+
"""Apply connection-level extension flags via SET statements."""
160+
161+
if not self._extension_flags:
162+
return
163+
164+
for key, value in self._extension_flags.items():
165+
if not key or not key.replace("_", "").isalnum():
166+
continue
167+
168+
normalized = self._normalize_flag_value(value)
169+
try:
170+
connection.execute(f"SET {key} = {normalized}")
171+
except Exception as exc: # pragma: no cover - best-effort guard
172+
logger.debug("Failed to set DuckDB flag %s: %s", key, exc)
173+
174+
@staticmethod
175+
def _normalize_flag_value(value: Any) -> str:
176+
"""Convert Python value to DuckDB SET literal."""
177+
178+
if isinstance(value, bool):
179+
return "TRUE" if value else "FALSE"
180+
if isinstance(value, (int, float)):
181+
return str(value)
182+
escaped = str(value).replace("'", "''")
183+
return f"'{escaped}'"
184+
152185
def _get_thread_connection(self) -> DuckDBConnection:
153186
"""Get or create a connection for the current thread.
154187
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
3+
pytest.importorskip("duckdb", reason="DuckDB adapter requires duckdb package")
4+
5+
from sqlspec.adapters.duckdb import DuckDBConfig
6+
7+
8+
def test_duckdb_config_promotes_security_flags() -> None:
9+
config = DuckDBConfig(
10+
pool_config={
11+
"database": ":memory:",
12+
"allow_community_extensions": True,
13+
"allow_unsigned_extensions": False,
14+
"enable_external_access": True,
15+
}
16+
)
17+
18+
flags = config.driver_features.get("extension_flags")
19+
assert flags == {
20+
"allow_community_extensions": True,
21+
"allow_unsigned_extensions": False,
22+
"enable_external_access": True,
23+
}
24+
assert "allow_community_extensions" not in config.pool_config
25+
assert "allow_unsigned_extensions" not in config.pool_config
26+
assert "enable_external_access" not in config.pool_config
27+
28+
29+
def test_duckdb_config_merges_existing_extension_flags() -> None:
30+
config = DuckDBConfig(
31+
pool_config={"database": ":memory:", "allow_community_extensions": True},
32+
driver_features={"extension_flags": {"custom": "value"}},
33+
)
34+
35+
flags = config.driver_features.get("extension_flags")
36+
assert flags == {"custom": "value", "allow_community_extensions": True}

0 commit comments

Comments
 (0)