Skip to content

Commit cdca6b3

Browse files
authored
Merge branch 'main' into 173_usage_configuration
2 parents e0b1082 + f1729a4 commit cdca6b3

File tree

5 files changed

+358
-96
lines changed

5 files changed

+358
-96
lines changed

sqlspec/adapters/oracledb/adk/store.py

Lines changed: 95 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
"""Oracle ADK store for Google Agent Development Kit session/event storage."""
22

3+
from decimal import Decimal
34
from enum import Enum
4-
from typing import TYPE_CHECKING, Any, Final
5+
from typing import TYPE_CHECKING, Any, Final, cast
56

67
import oracledb
78

89
from sqlspec import SQL
10+
from sqlspec.adapters.oracledb.data_dictionary import (
11+
OracleAsyncDataDictionary,
12+
OracleSyncDataDictionary,
13+
OracleVersionInfo,
14+
)
915
from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord
1016
from sqlspec.utils.logging import get_logger
1117
from sqlspec.utils.serializers import from_json, to_json
@@ -33,6 +39,41 @@ class JSONStorageType(str, Enum):
3339
BLOB_PLAIN = "blob_plain"
3440

3541

42+
def _coerce_decimal_values(value: Any) -> Any:
43+
if isinstance(value, Decimal):
44+
return float(value)
45+
if isinstance(value, dict):
46+
return {key: _coerce_decimal_values(val) for key, val in value.items()}
47+
if isinstance(value, list):
48+
return [_coerce_decimal_values(item) for item in value]
49+
if isinstance(value, tuple):
50+
return tuple(_coerce_decimal_values(item) for item in value)
51+
if isinstance(value, set):
52+
return {_coerce_decimal_values(item) for item in value}
53+
if isinstance(value, frozenset):
54+
return frozenset(_coerce_decimal_values(item) for item in value)
55+
return value
56+
57+
58+
def _storage_type_from_version(version_info: "OracleVersionInfo | None") -> JSONStorageType:
59+
"""Determine JSON storage type based on Oracle version metadata."""
60+
61+
if version_info and version_info.supports_native_json():
62+
logger.debug("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_info)
63+
return JSONStorageType.JSON_NATIVE
64+
65+
if version_info and version_info.supports_json_blob():
66+
logger.debug("Detected Oracle %s, using BLOB_JSON (recommended)", version_info)
67+
return JSONStorageType.BLOB_JSON
68+
69+
if version_info:
70+
logger.debug("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_info)
71+
return JSONStorageType.BLOB_PLAIN
72+
73+
logger.warning("Oracle version could not be detected; defaulting to BLOB_JSON storage")
74+
return JSONStorageType.BLOB_JSON
75+
76+
3677
def _to_oracle_bool(value: "bool | None") -> "int | None":
3778
"""Convert Python boolean to Oracle NUMBER(1).
3879
@@ -103,7 +144,7 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]):
103144
- Configuration is read from config.extension_config["adk"]
104145
"""
105146

106-
__slots__ = ("_in_memory", "_json_storage_type")
147+
__slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info")
107148

108149
def __init__(self, config: "OracleAsyncConfig") -> None:
109150
"""Initialize Oracle ADK store.
@@ -120,6 +161,7 @@ def __init__(self, config: "OracleAsyncConfig") -> None:
120161
"""
121162
super().__init__(config)
122163
self._json_storage_type: JSONStorageType | None = None
164+
self._oracle_version_info: OracleVersionInfo | None = None
123165

124166
adk_config = config.extension_config.get("adk", {})
125167
self._in_memory: bool = bool(adk_config.get("in_memory", False))
@@ -160,44 +202,24 @@ async def _detect_json_storage_type(self) -> JSONStorageType:
160202
if self._json_storage_type is not None:
161203
return self._json_storage_type
162204

163-
async with self._config.provide_connection() as conn:
164-
cursor = conn.cursor()
165-
await cursor.execute(
166-
"""
167-
SELECT version FROM product_component_version
168-
WHERE product LIKE 'Oracle%Database%'
169-
"""
170-
)
171-
row = await cursor.fetchone()
172-
173-
if row is None:
174-
logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON")
175-
self._json_storage_type = JSONStorageType.BLOB_JSON
176-
return self._json_storage_type
177-
178-
version_str = str(row[0])
179-
version_parts = version_str.split(".")
180-
major_version = int(version_parts[0]) if version_parts else 0
181-
182-
if major_version >= ORACLE_MIN_JSON_NATIVE_VERSION:
183-
await cursor.execute("SELECT value FROM v$parameter WHERE name = 'compatible'")
184-
compatible_row = await cursor.fetchone()
185-
if compatible_row:
186-
compatible_parts = str(compatible_row[0]).split(".")
187-
compatible_major = int(compatible_parts[0]) if compatible_parts else 0
188-
if compatible_major >= ORACLE_MIN_JSON_NATIVE_COMPATIBLE:
189-
logger.info("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_str)
190-
self._json_storage_type = JSONStorageType.JSON_NATIVE
191-
return self._json_storage_type
192-
193-
if major_version >= ORACLE_MIN_JSON_BLOB_VERSION:
194-
logger.info("Detected Oracle %s, using BLOB_JSON (recommended)", version_str)
195-
self._json_storage_type = JSONStorageType.BLOB_JSON
196-
return self._json_storage_type
197-
198-
logger.info("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_str)
199-
self._json_storage_type = JSONStorageType.BLOB_PLAIN
200-
return self._json_storage_type
205+
version_info = await self._get_version_info()
206+
self._json_storage_type = _storage_type_from_version(version_info)
207+
return self._json_storage_type
208+
209+
async def _get_version_info(self) -> "OracleVersionInfo | None":
210+
"""Return cached Oracle version info using Oracle data dictionary."""
211+
212+
if self._oracle_version_info is not None:
213+
return self._oracle_version_info
214+
215+
async with self._config.provide_session() as driver:
216+
dictionary = OracleAsyncDataDictionary()
217+
self._oracle_version_info = await dictionary.get_version(driver)
218+
219+
if self._oracle_version_info is None:
220+
logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage")
221+
222+
return self._oracle_version_info
201223

202224
async def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes":
203225
"""Serialize state dictionary to appropriate format based on storage type.
@@ -232,7 +254,7 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]":
232254
data = await data.read()
233255

234256
if isinstance(data, dict):
235-
return data
257+
return cast("dict[str, Any]", _coerce_decimal_values(data))
236258

237259
if isinstance(data, bytes):
238260
return from_json(data) # type: ignore[no-any-return]
@@ -280,7 +302,7 @@ async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
280302
data = await data.read()
281303

282304
if isinstance(data, dict):
283-
return data
305+
return cast("dict[str, Any]", _coerce_decimal_values(data))
284306

285307
if isinstance(data, bytes):
286308
return from_json(data) # type: ignore[no-any-return]
@@ -490,7 +512,7 @@ async def create_tables(self) -> None:
490512
Uses version-appropriate table schema.
491513
"""
492514
storage_type = await self._detect_json_storage_type()
493-
logger.info("Creating ADK tables with storage type: %s", storage_type)
515+
logger.debug("Creating ADK tables with storage type: %s", storage_type)
494516

495517
async with self._config.provide_session() as driver:
496518
await driver.execute_script(self._get_create_sessions_table_sql_for_type(storage_type))
@@ -561,16 +583,17 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
561583
State is deserialized using version-appropriate format.
562584
"""
563585

564-
sql = f"""
565-
SELECT id, app_name, user_id, state, create_time, update_time
566-
FROM {self._session_table}
567-
WHERE id = :id
568-
"""
569-
570586
try:
571587
async with self._config.provide_connection() as conn:
572588
cursor = conn.cursor()
573-
await cursor.execute(sql, {"id": session_id})
589+
await cursor.execute(
590+
f"""
591+
SELECT id, app_name, user_id, state, create_time, update_time
592+
FROM {self._session_table}
593+
WHERE id = :id
594+
""",
595+
{"id": session_id},
596+
)
574597
row = await cursor.fetchone()
575598

576599
if row is None:
@@ -881,7 +904,7 @@ class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]):
881904
- Configuration is read from config.extension_config["adk"]
882905
"""
883906

884-
__slots__ = ("_in_memory", "_json_storage_type")
907+
__slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info")
885908

886909
def __init__(self, config: "OracleSyncConfig") -> None:
887910
"""Initialize Oracle synchronous ADK store.
@@ -898,6 +921,7 @@ def __init__(self, config: "OracleSyncConfig") -> None:
898921
"""
899922
super().__init__(config)
900923
self._json_storage_type: JSONStorageType | None = None
924+
self._oracle_version_info: OracleVersionInfo | None = None
901925

902926
adk_config = config.extension_config.get("adk", {})
903927
self._in_memory: bool = bool(adk_config.get("in_memory", False))
@@ -938,44 +962,24 @@ def _detect_json_storage_type(self) -> JSONStorageType:
938962
if self._json_storage_type is not None:
939963
return self._json_storage_type
940964

941-
with self._config.provide_connection() as conn:
942-
cursor = conn.cursor()
943-
cursor.execute(
944-
"""
945-
SELECT version FROM product_component_version
946-
WHERE product LIKE 'Oracle%Database%'
947-
"""
948-
)
949-
row = cursor.fetchone()
950-
951-
if row is None:
952-
logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON")
953-
self._json_storage_type = JSONStorageType.BLOB_JSON
954-
return self._json_storage_type
955-
956-
version_str = str(row[0])
957-
version_parts = version_str.split(".")
958-
major_version = int(version_parts[0]) if version_parts else 0
959-
960-
if major_version >= ORACLE_MIN_JSON_NATIVE_VERSION:
961-
cursor.execute("SELECT value FROM v$parameter WHERE name = 'compatible'")
962-
compatible_row = cursor.fetchone()
963-
if compatible_row:
964-
compatible_parts = str(compatible_row[0]).split(".")
965-
compatible_major = int(compatible_parts[0]) if compatible_parts else 0
966-
if compatible_major >= ORACLE_MIN_JSON_NATIVE_COMPATIBLE:
967-
logger.info("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_str)
968-
self._json_storage_type = JSONStorageType.JSON_NATIVE
969-
return self._json_storage_type
970-
971-
if major_version >= ORACLE_MIN_JSON_BLOB_VERSION:
972-
logger.info("Detected Oracle %s, using BLOB_JSON (recommended)", version_str)
973-
self._json_storage_type = JSONStorageType.BLOB_JSON
974-
return self._json_storage_type
975-
976-
logger.info("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_str)
977-
self._json_storage_type = JSONStorageType.BLOB_PLAIN
978-
return self._json_storage_type
965+
version_info = self._get_version_info()
966+
self._json_storage_type = _storage_type_from_version(version_info)
967+
return self._json_storage_type
968+
969+
def _get_version_info(self) -> "OracleVersionInfo | None":
970+
"""Return cached Oracle version info using Oracle data dictionary."""
971+
972+
if self._oracle_version_info is not None:
973+
return self._oracle_version_info
974+
975+
with self._config.provide_session() as driver:
976+
dictionary = OracleSyncDataDictionary()
977+
self._oracle_version_info = dictionary.get_version(driver)
978+
979+
if self._oracle_version_info is None:
980+
logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage")
981+
982+
return self._oracle_version_info
979983

980984
def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes":
981985
"""Serialize state dictionary to appropriate format based on storage type.
@@ -1010,7 +1014,7 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]":
10101014
data = data.read()
10111015

10121016
if isinstance(data, dict):
1013-
return data
1017+
return cast("dict[str, Any]", _coerce_decimal_values(data))
10141018

10151019
if isinstance(data, bytes):
10161020
return from_json(data) # type: ignore[no-any-return]
@@ -1058,7 +1062,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
10581062
data = data.read()
10591063

10601064
if isinstance(data, dict):
1061-
return data
1065+
return cast("dict[str, Any]", _coerce_decimal_values(data))
10621066

10631067
if isinstance(data, bytes):
10641068
return from_json(data) # type: ignore[no-any-return]

sqlspec/core/parameters/_validator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
(?P<block_comment>/\*(?:[^*]|\*(?!/))*\*/) |
1919
(?P<pg_q_operator>\?\?|\?\||\?&) |
2020
(?P<pg_cast>::(?P<cast_type>\w+)) |
21+
(?P<sql_server_global>@@(?P<global_var_name>\w+)) |
2122
(?P<pyformat_named>%\((?P<pyformat_name>\w+)\)s) |
2223
(?P<pyformat_pos>%s) |
23-
(?P<positional_colon>:(?P<colon_num>\d+)) |
24-
(?P<named_colon>:(?P<colon_name>\w+)) |
25-
(?P<named_at>@(?P<at_name>\w+)) |
26-
(?P<numeric>\$(?P<numeric_num>\d+)) |
27-
(?P<named_dollar_param>\$(?P<dollar_param_name>\w+)) |
24+
(?P<positional_colon>(?<![A-Za-z0-9_]):(?P<colon_num>\d+)) |
25+
(?P<named_colon>(?<![A-Za-z0-9_]):(?P<colon_name>\w+)) |
26+
(?P<named_at>(?<![A-Za-z0-9_])@(?P<at_name>\w+)) |
27+
(?P<numeric>(?<![A-Za-z0-9_])\$(?P<numeric_num>\d+)) |
28+
(?P<named_dollar_param>(?<![A-Za-z0-9_])\$(?P<dollar_param_name>\w+)) |
2829
(?P<qmark>\?)
2930
""",
3031
re.VERBOSE | re.IGNORECASE | re.MULTILINE | re.DOTALL,
@@ -85,6 +86,7 @@ def extract_parameters(self, sql: str) -> "list[ParameterInfo]":
8586
"block_comment",
8687
"pg_q_operator",
8788
"pg_cast",
89+
"sql_server_global",
8890
)
8991

9092
for match in PARAMETER_REGEX.finditer(sql):
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Tests for Oracle ADK store Decimal coercion."""
2+
3+
from decimal import Decimal
4+
5+
import pytest
6+
7+
from sqlspec.adapters.oracledb.adk.store import OracleAsyncADKStore, OracleSyncADKStore
8+
9+
10+
@pytest.mark.asyncio
11+
async def test_oracle_async_adk_store_deserialize_dict_coerces_decimal() -> None:
12+
store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg]
13+
14+
payload = {"value": Decimal("1.25"), "nested": {"score": Decimal("0.5")}}
15+
16+
result = await store._deserialize_json_field(payload) # type: ignore[attr-defined]
17+
18+
assert result == {"value": 1.25, "nested": {"score": 0.5}}
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_oracle_async_adk_store_deserialize_state_dict_coerces_decimal() -> None:
23+
store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg]
24+
25+
payload = {"state": Decimal("2.0")}
26+
27+
result = await store._deserialize_state(payload) # type: ignore[attr-defined]
28+
29+
assert result == {"state": 2.0}
30+
31+
32+
def test_oracle_sync_adk_store_deserialize_dict_coerces_decimal() -> None:
33+
store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg]
34+
35+
payload = {"value": Decimal("3.14"), "items": [Decimal("1.0"), Decimal("2.0")]}
36+
37+
result = store._deserialize_json_field(payload) # type: ignore[attr-defined]
38+
39+
assert result == {"value": 3.14, "items": [1.0, 2.0]}
40+
41+
42+
def test_oracle_sync_adk_store_deserialize_state_dict_coerces_decimal() -> None:
43+
store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg]
44+
45+
payload = {"state": Decimal("5.0")}
46+
47+
result = store._deserialize_state(payload) # type: ignore[attr-defined]
48+
49+
assert result == {"state": 5.0}

0 commit comments

Comments
 (0)