22import os
33import pickle
44import sqlite3
5+ from pathlib import Path
56import asyncstdlib as a
67
78USE_CACHE = True if os .getenv ("NO_CACHE" ) != "1" else False
1415)
1516
1617
18+ def _ensure_dir ():
19+ path = Path (CACHE_LOCATION ).parent
20+ if not path .exists ():
21+ path .mkdir (parents = True , exist_ok = True )
22+
23+
1724def _get_table_name (func ):
1825 """Convert "ClassName.method_name" to "ClassName_method_name"""
1926 return func .__qualname__ .replace ("." , "_" )
@@ -74,22 +81,32 @@ def _insert_into_cache(c, conn, table_name, key, result, chain):
7481 pass
7582
7683
77- def sql_lru_cache (maxsize = None ):
78- def decorator (func ):
84+ def _shared_inner_fn_logic (func , self , args , kwargs ):
85+ chain = self .url
86+ if not (local_chain := _check_if_local (chain )) or not USE_CACHE :
87+ _ensure_dir ()
7988 conn = sqlite3 .connect (CACHE_LOCATION )
8089 c = conn .cursor ()
8190 table_name = _get_table_name (func )
8291 _create_table (c , conn , table_name )
92+ key = pickle .dumps ((args , kwargs ))
93+ result = _retrieve_from_cache (c , table_name , key , chain )
94+ else :
95+ result = None
96+ c = None
97+ conn = None
98+ table_name = None
99+ key = None
100+ return c , conn , table_name , key , result , chain , local_chain
101+
83102
103+ def sql_lru_cache (maxsize = None ):
104+ def decorator (func ):
84105 @functools .lru_cache (maxsize = maxsize )
85106 def inner (self , * args , ** kwargs ):
86- c = conn .cursor ()
87- key = pickle .dumps ((args , kwargs ))
88- chain = self .url
89- if not (local_chain := _check_if_local (chain )) or not USE_CACHE :
90- result = _retrieve_from_cache (c , table_name , key , chain )
91- if result is not None :
92- return result
107+ c , conn , table_name , key , result , chain , local_chain = (
108+ _shared_inner_fn_logic (func , self , args , kwargs )
109+ )
93110
94111 # If not in DB, call func and store in DB
95112 result = func (self , * args , ** kwargs )
@@ -106,21 +123,11 @@ def inner(self, *args, **kwargs):
106123
107124def async_sql_lru_cache (maxsize = None ):
108125 def decorator (func ):
109- conn = sqlite3 .connect (CACHE_LOCATION )
110- c = conn .cursor ()
111- table_name = _get_table_name (func )
112- _create_table (c , conn , table_name )
113-
114126 @a .lru_cache (maxsize = maxsize )
115127 async def inner (self , * args , ** kwargs ):
116- c = conn .cursor ()
117- key = pickle .dumps ((args , kwargs ))
118- chain = self .url
119-
120- if not (local_chain := _check_if_local (chain )) or not USE_CACHE :
121- result = _retrieve_from_cache (c , table_name , key , chain )
122- if result is not None :
123- return result
128+ c , conn , table_name , key , result , chain , local_chain = (
129+ _shared_inner_fn_logic (func , self , args , kwargs )
130+ )
124131
125132 # If not in DB, call func and store in DB
126133 result = await func (self , * args , ** kwargs )
0 commit comments