22import os
33import pickle
44import sqlite3
5+ from pathlib import Path
56import asyncstdlib as a
67
78USE_CACHE = True if os .getenv ("NO_CACHE" ) != "1" else False
1415)
1516
1617
18+ def _ensure_dir ():
19+ path = Path (CACHE_LOCATION ).parent
20+ if not path .exists ():
21+ path .mkdir (parents = True , exist_ok = True )
22+
23+
1724def _get_table_name (func ):
1825 """Convert "ClassName.method_name" to "ClassName_method_name"""
1926 return func .__qualname__ .replace ("." , "_" )
@@ -74,22 +81,28 @@ def _insert_into_cache(c, conn, table_name, key, result, chain):
7481 pass
7582
7683
84+ def _shared_inner_fn_logic (func , self , args , kwargs ):
85+ _ensure_dir ()
86+ conn = sqlite3 .connect (CACHE_LOCATION )
87+ c = conn .cursor ()
88+ table_name = _get_table_name (func )
89+ _create_table (c , conn , table_name )
90+ key = pickle .dumps ((args , kwargs ))
91+ chain = self .url
92+ if not (local_chain := _check_if_local (chain )) or not USE_CACHE :
93+ result = _retrieve_from_cache (c , table_name , key , chain )
94+ else :
95+ result = None
96+ return c , conn , table_name , key , result , chain , local_chain
97+
98+
7799def sql_lru_cache (maxsize = None ):
78100 def decorator (func ):
79- conn = sqlite3 .connect (CACHE_LOCATION )
80- c = conn .cursor ()
81- table_name = _get_table_name (func )
82- _create_table (c , conn , table_name )
83-
84101 @functools .lru_cache (maxsize = maxsize )
85102 def inner (self , * args , ** kwargs ):
86- c = conn .cursor ()
87- key = pickle .dumps ((args , kwargs ))
88- chain = self .url
89- if not (local_chain := _check_if_local (chain )) or not USE_CACHE :
90- result = _retrieve_from_cache (c , table_name , key , chain )
91- if result is not None :
92- return result
103+ c , conn , table_name , key , result , chain , local_chain = (
104+ _shared_inner_fn_logic (func , self , args , kwargs )
105+ )
93106
94107 # If not in DB, call func and store in DB
95108 result = func (self , * args , ** kwargs )
@@ -106,21 +119,11 @@ def inner(self, *args, **kwargs):
106119
107120def async_sql_lru_cache (maxsize = None ):
108121 def decorator (func ):
109- conn = sqlite3 .connect (CACHE_LOCATION )
110- c = conn .cursor ()
111- table_name = _get_table_name (func )
112- _create_table (c , conn , table_name )
113-
114122 @a .lru_cache (maxsize = maxsize )
115123 async def inner (self , * args , ** kwargs ):
116- c = conn .cursor ()
117- key = pickle .dumps ((args , kwargs ))
118- chain = self .url
119-
120- if not (local_chain := _check_if_local (chain )) or not USE_CACHE :
121- result = _retrieve_from_cache (c , table_name , key , chain )
122- if result is not None :
123- return result
124+ c , conn , table_name , key , result , chain , local_chain = (
125+ _shared_inner_fn_logic (func , self , args , kwargs )
126+ )
124127
125128 # If not in DB, call func and store in DB
126129 result = await func (self , * args , ** kwargs )
0 commit comments