11import asyncio
2+ import inspect
23from collections import OrderedDict
34import functools
45import logging
56import os
67import pickle
78import sqlite3
89from pathlib import Path
9- from typing import Callable , Any , Awaitable , Hashable
10-
10+ from typing import Callable , Any , Awaitable , Hashable , Optional
1111
1212USE_CACHE = True if os .getenv ("NO_CACHE" ) != "1" else False
1313CACHE_LOCATION = (
@@ -208,22 +208,33 @@ def __init__(
208208 self ,
209209 max_size : int ,
210210 method : Callable [..., Awaitable [Any ]],
211- cache_key_index : int = 0 ,
211+ cache_key_index : Optional [ int ] = 0 ,
212212 ):
213213 """
214214 Args:
215215 max_size: max size of the cache (in items)
216216 method: the function to cache
217- cache_key_index: if the method takes multiple args, only one will be used as the cache key. This is the
218- index of that cache key in the args list (default is the first arg)
217+ cache_key_index: if the method takes multiple args, this is the index of that cache key in the args list
218+ (default is the first arg). By setting this to `None`, it will use all args as the cache key.
219219 """
220220 self ._inflight : dict [Hashable , asyncio .Future ] = {}
221221 self ._method = method
222222 self ._cache = LRUCache (max_size = max_size )
223223 self ._cache_key_index = cache_key_index
224224
225- async def __call__ (self , * args : Any ) -> Any :
226- key = args [self ._cache_key_index ]
225+ def make_cache_key (self , args : tuple , kwargs : dict ) -> Hashable :
226+ bound = inspect .signature (self ._method ).bind (* args , ** kwargs )
227+ bound .apply_defaults ()
228+
229+ if self ._cache_key_index is not None :
230+ key_name = list (bound .arguments )[self ._cache_key_index ]
231+ return bound .arguments [key_name ]
232+
233+ return (tuple (bound .arguments .items ()),)
234+
235+ async def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
236+ key = self .make_cache_key (args , kwargs )
237+
227238 if item := self ._cache .get (key ):
228239 return item
229240
@@ -235,7 +246,7 @@ async def __call__(self, *args: Any) -> Any:
235246 self ._inflight [key ] = future
236247
237248 try :
238- result = await self ._method (* args )
249+ result = await self ._method (* args , ** kwargs )
239250 self ._cache .set (key , result )
240251 future .set_result (result )
241252 return result
0 commit comments