11"""Oracle ADK store for Google Agent Development Kit session/event storage."""
22
3+ from decimal import Decimal
34from enum import Enum
4- from typing import TYPE_CHECKING , Any , Final
5+ from typing import TYPE_CHECKING , Any , Final , cast
56
67import oracledb
78
89from sqlspec import SQL
10+ from sqlspec .adapters .oracledb .data_dictionary import (
11+ OracleAsyncDataDictionary ,
12+ OracleSyncDataDictionary ,
13+ OracleVersionInfo ,
14+ )
915from sqlspec .extensions .adk import BaseAsyncADKStore , BaseSyncADKStore , EventRecord , SessionRecord
1016from sqlspec .utils .logging import get_logger
1117from sqlspec .utils .serializers import from_json , to_json
@@ -33,6 +39,41 @@ class JSONStorageType(str, Enum):
3339 BLOB_PLAIN = "blob_plain"
3440
3541
42+ def _coerce_decimal_values (value : Any ) -> Any :
43+ if isinstance (value , Decimal ):
44+ return float (value )
45+ if isinstance (value , dict ):
46+ return {key : _coerce_decimal_values (val ) for key , val in value .items ()}
47+ if isinstance (value , list ):
48+ return [_coerce_decimal_values (item ) for item in value ]
49+ if isinstance (value , tuple ):
50+ return tuple (_coerce_decimal_values (item ) for item in value )
51+ if isinstance (value , set ):
52+ return {_coerce_decimal_values (item ) for item in value }
53+ if isinstance (value , frozenset ):
54+ return frozenset (_coerce_decimal_values (item ) for item in value )
55+ return value
56+
57+
58+ def _storage_type_from_version (version_info : "OracleVersionInfo | None" ) -> JSONStorageType :
59+ """Determine JSON storage type based on Oracle version metadata."""
60+
61+ if version_info and version_info .supports_native_json ():
62+ logger .debug ("Detected Oracle %s with compatible >= 20, using JSON_NATIVE" , version_info )
63+ return JSONStorageType .JSON_NATIVE
64+
65+ if version_info and version_info .supports_json_blob ():
66+ logger .debug ("Detected Oracle %s, using BLOB_JSON (recommended)" , version_info )
67+ return JSONStorageType .BLOB_JSON
68+
69+ if version_info :
70+ logger .debug ("Detected Oracle %s (pre-12c), using BLOB_PLAIN" , version_info )
71+ return JSONStorageType .BLOB_PLAIN
72+
73+ logger .warning ("Oracle version could not be detected; defaulting to BLOB_JSON storage" )
74+ return JSONStorageType .BLOB_JSON
75+
76+
3677def _to_oracle_bool (value : "bool | None" ) -> "int | None" :
3778 """Convert Python boolean to Oracle NUMBER(1).
3879
@@ -103,7 +144,7 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]):
103144 - Configuration is read from config.extension_config["adk"]
104145 """
105146
106- __slots__ = ("_in_memory" , "_json_storage_type" )
147+ __slots__ = ("_in_memory" , "_json_storage_type" , "_oracle_version_info" )
107148
108149 def __init__ (self , config : "OracleAsyncConfig" ) -> None :
109150 """Initialize Oracle ADK store.
@@ -120,6 +161,7 @@ def __init__(self, config: "OracleAsyncConfig") -> None:
120161 """
121162 super ().__init__ (config )
122163 self ._json_storage_type : JSONStorageType | None = None
164+ self ._oracle_version_info : OracleVersionInfo | None = None
123165
124166 adk_config = config .extension_config .get ("adk" , {})
125167 self ._in_memory : bool = bool (adk_config .get ("in_memory" , False ))
@@ -160,44 +202,24 @@ async def _detect_json_storage_type(self) -> JSONStorageType:
160202 if self ._json_storage_type is not None :
161203 return self ._json_storage_type
162204
163- async with self ._config .provide_connection () as conn :
164- cursor = conn .cursor ()
165- await cursor .execute (
166- """
167- SELECT version FROM product_component_version
168- WHERE product LIKE 'Oracle%Database%'
169- """
170- )
171- row = await cursor .fetchone ()
172-
173- if row is None :
174- logger .warning ("Could not detect Oracle version, defaulting to BLOB_JSON" )
175- self ._json_storage_type = JSONStorageType .BLOB_JSON
176- return self ._json_storage_type
177-
178- version_str = str (row [0 ])
179- version_parts = version_str .split ("." )
180- major_version = int (version_parts [0 ]) if version_parts else 0
181-
182- if major_version >= ORACLE_MIN_JSON_NATIVE_VERSION :
183- await cursor .execute ("SELECT value FROM v$parameter WHERE name = 'compatible'" )
184- compatible_row = await cursor .fetchone ()
185- if compatible_row :
186- compatible_parts = str (compatible_row [0 ]).split ("." )
187- compatible_major = int (compatible_parts [0 ]) if compatible_parts else 0
188- if compatible_major >= ORACLE_MIN_JSON_NATIVE_COMPATIBLE :
189- logger .info ("Detected Oracle %s with compatible >= 20, using JSON_NATIVE" , version_str )
190- self ._json_storage_type = JSONStorageType .JSON_NATIVE
191- return self ._json_storage_type
192-
193- if major_version >= ORACLE_MIN_JSON_BLOB_VERSION :
194- logger .info ("Detected Oracle %s, using BLOB_JSON (recommended)" , version_str )
195- self ._json_storage_type = JSONStorageType .BLOB_JSON
196- return self ._json_storage_type
197-
198- logger .info ("Detected Oracle %s (pre-12c), using BLOB_PLAIN" , version_str )
199- self ._json_storage_type = JSONStorageType .BLOB_PLAIN
200- return self ._json_storage_type
205+ version_info = await self ._get_version_info ()
206+ self ._json_storage_type = _storage_type_from_version (version_info )
207+ return self ._json_storage_type
208+
209+ async def _get_version_info (self ) -> "OracleVersionInfo | None" :
210+ """Return cached Oracle version info using Oracle data dictionary."""
211+
212+ if self ._oracle_version_info is not None :
213+ return self ._oracle_version_info
214+
215+ async with self ._config .provide_session () as driver :
216+ dictionary = OracleAsyncDataDictionary ()
217+ self ._oracle_version_info = await dictionary .get_version (driver )
218+
219+ if self ._oracle_version_info is None :
220+ logger .warning ("Could not detect Oracle version, defaulting to BLOB_JSON storage" )
221+
222+ return self ._oracle_version_info
201223
202224 async def _serialize_state (self , state : "dict[str, Any]" ) -> "str | bytes" :
203225 """Serialize state dictionary to appropriate format based on storage type.
@@ -232,7 +254,7 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]":
232254 data = await data .read ()
233255
234256 if isinstance (data , dict ):
235- return data
257+ return cast ( "dict[str, Any]" , _coerce_decimal_values ( data ))
236258
237259 if isinstance (data , bytes ):
238260 return from_json (data ) # type: ignore[no-any-return]
@@ -280,7 +302,7 @@ async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
280302 data = await data .read ()
281303
282304 if isinstance (data , dict ):
283- return data
305+ return cast ( "dict[str, Any]" , _coerce_decimal_values ( data ))
284306
285307 if isinstance (data , bytes ):
286308 return from_json (data ) # type: ignore[no-any-return]
@@ -490,7 +512,7 @@ async def create_tables(self) -> None:
490512 Uses version-appropriate table schema.
491513 """
492514 storage_type = await self ._detect_json_storage_type ()
493- logger .info ("Creating ADK tables with storage type: %s" , storage_type )
515+ logger .debug ("Creating ADK tables with storage type: %s" , storage_type )
494516
495517 async with self ._config .provide_session () as driver :
496518 await driver .execute_script (self ._get_create_sessions_table_sql_for_type (storage_type ))
@@ -561,16 +583,17 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
561583 State is deserialized using version-appropriate format.
562584 """
563585
564- sql = f"""
565- SELECT id, app_name, user_id, state, create_time, update_time
566- FROM { self ._session_table }
567- WHERE id = :id
568- """
569-
570586 try :
571587 async with self ._config .provide_connection () as conn :
572588 cursor = conn .cursor ()
573- await cursor .execute (sql , {"id" : session_id })
589+ await cursor .execute (
590+ f"""
591+ SELECT id, app_name, user_id, state, create_time, update_time
592+ FROM { self ._session_table }
593+ WHERE id = :id
594+ """ ,
595+ {"id" : session_id },
596+ )
574597 row = await cursor .fetchone ()
575598
576599 if row is None :
@@ -881,7 +904,7 @@ class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]):
881904 - Configuration is read from config.extension_config["adk"]
882905 """
883906
884- __slots__ = ("_in_memory" , "_json_storage_type" )
907+ __slots__ = ("_in_memory" , "_json_storage_type" , "_oracle_version_info" )
885908
886909 def __init__ (self , config : "OracleSyncConfig" ) -> None :
887910 """Initialize Oracle synchronous ADK store.
@@ -898,6 +921,7 @@ def __init__(self, config: "OracleSyncConfig") -> None:
898921 """
899922 super ().__init__ (config )
900923 self ._json_storage_type : JSONStorageType | None = None
924+ self ._oracle_version_info : OracleVersionInfo | None = None
901925
902926 adk_config = config .extension_config .get ("adk" , {})
903927 self ._in_memory : bool = bool (adk_config .get ("in_memory" , False ))
@@ -938,44 +962,24 @@ def _detect_json_storage_type(self) -> JSONStorageType:
938962 if self ._json_storage_type is not None :
939963 return self ._json_storage_type
940964
941- with self ._config .provide_connection () as conn :
942- cursor = conn .cursor ()
943- cursor .execute (
944- """
945- SELECT version FROM product_component_version
946- WHERE product LIKE 'Oracle%Database%'
947- """
948- )
949- row = cursor .fetchone ()
950-
951- if row is None :
952- logger .warning ("Could not detect Oracle version, defaulting to BLOB_JSON" )
953- self ._json_storage_type = JSONStorageType .BLOB_JSON
954- return self ._json_storage_type
955-
956- version_str = str (row [0 ])
957- version_parts = version_str .split ("." )
958- major_version = int (version_parts [0 ]) if version_parts else 0
959-
960- if major_version >= ORACLE_MIN_JSON_NATIVE_VERSION :
961- cursor .execute ("SELECT value FROM v$parameter WHERE name = 'compatible'" )
962- compatible_row = cursor .fetchone ()
963- if compatible_row :
964- compatible_parts = str (compatible_row [0 ]).split ("." )
965- compatible_major = int (compatible_parts [0 ]) if compatible_parts else 0
966- if compatible_major >= ORACLE_MIN_JSON_NATIVE_COMPATIBLE :
967- logger .info ("Detected Oracle %s with compatible >= 20, using JSON_NATIVE" , version_str )
968- self ._json_storage_type = JSONStorageType .JSON_NATIVE
969- return self ._json_storage_type
970-
971- if major_version >= ORACLE_MIN_JSON_BLOB_VERSION :
972- logger .info ("Detected Oracle %s, using BLOB_JSON (recommended)" , version_str )
973- self ._json_storage_type = JSONStorageType .BLOB_JSON
974- return self ._json_storage_type
975-
976- logger .info ("Detected Oracle %s (pre-12c), using BLOB_PLAIN" , version_str )
977- self ._json_storage_type = JSONStorageType .BLOB_PLAIN
978- return self ._json_storage_type
965+ version_info = self ._get_version_info ()
966+ self ._json_storage_type = _storage_type_from_version (version_info )
967+ return self ._json_storage_type
968+
969+ def _get_version_info (self ) -> "OracleVersionInfo | None" :
970+ """Return cached Oracle version info using Oracle data dictionary."""
971+
972+ if self ._oracle_version_info is not None :
973+ return self ._oracle_version_info
974+
975+ with self ._config .provide_session () as driver :
976+ dictionary = OracleSyncDataDictionary ()
977+ self ._oracle_version_info = dictionary .get_version (driver )
978+
979+ if self ._oracle_version_info is None :
980+ logger .warning ("Could not detect Oracle version, defaulting to BLOB_JSON storage" )
981+
982+ return self ._oracle_version_info
979983
980984 def _serialize_state (self , state : "dict[str, Any]" ) -> "str | bytes" :
981985 """Serialize state dictionary to appropriate format based on storage type.
@@ -1010,7 +1014,7 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]":
10101014 data = data .read ()
10111015
10121016 if isinstance (data , dict ):
1013- return data
1017+ return cast ( "dict[str, Any]" , _coerce_decimal_values ( data ))
10141018
10151019 if isinstance (data , bytes ):
10161020 return from_json (data ) # type: ignore[no-any-return]
@@ -1058,7 +1062,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
10581062 data = data .read ()
10591063
10601064 if isinstance (data , dict ):
1061- return data
1065+ return cast ( "dict[str, Any]" , _coerce_decimal_values ( data ))
10621066
10631067 if isinstance (data , bytes ):
10641068 return from_json (data ) # type: ignore[no-any-return]
0 commit comments