1515CACHE_WARNING += "and should never be used or depended upon as it is not supported! "
1616CACHE_WARNING += "All caching capabilities are not tested and may be removed/changed "
1717CACHE_WARNING += "without prior notice. Please proceed with caution!"
18+ CACHE_CLEARED = True
1819
1920
2021def get_njit_funcs ():
@@ -102,58 +103,78 @@ def _enable():
102103 raise
103104
104105
105- def _clear ():
106+ def _clear (cache_dir = None ):
106107 """
107108 Clear numba cache
108109
109110 Parameters
110111 ----------
111- None
112+ cache_dir : str, default None
113+ The path to the numba cache directory
112114
113115 Returns
114116 -------
115117 None
116118 """
117- site_pkg_dir = site .getsitepackages ()[0 ]
118- numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
119+ global CACHE_CLEARED
120+
121+ if cache_dir is not None :
122+ numba_cache_dir = str (cache_dir )
123+ else : # pragma: no cover
124+ site_pkg_dir = site .getsitepackages ()[0 ]
125+ numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
126+
119127 [f .unlink () for f in pathlib .Path (numba_cache_dir ).glob ("*nb*" ) if f .is_file ()]
120128
129+ CACHE_CLEARED = True
130+
121131
122- def clear ():
132+ def clear (cache_dir = None ):
123133 """
124134 Clear numba cache directory
125135
126136 Parameters
127137 ----------
128- None
138+ cache_dir : str, default None
139+ The path to the numba cache directory. When `cache_dir` is `None`, then this
140+ defaults to `site-packages/stumpy/__pycache__`.
129141
130142 Returns
131143 -------
132144 None
133145 """
134146 warnings .warn (CACHE_WARNING )
135- _clear ()
147+ _clear (cache_dir )
136148
137149 return
138150
139151
140- def _get_cache ():
152+ def _get_cache (cache_dir = None ):
141153 """
142154 Retrieve a list of cached numba functions
143155
144156 Parameters
145157 ----------
146- None
158+ cache_dir : str
159+ The path to the numba cache directory
147160
148161 Returns
149162 -------
150163 out : list
151164 A list of cached numba functions
152165 """
153166 warnings .warn (CACHE_WARNING )
154- site_pkg_dir = site .getsitepackages ()[0 ]
155- numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
156- return [f .name for f in pathlib .Path (numba_cache_dir ).glob ("*nb*" ) if f .is_file ()]
167+ if cache_dir is not None :
168+ numba_cache_dir = str (cache_dir )
169+ else : # pragma: no cover
170+ site_pkg_dir = site .getsitepackages ()[0 ]
171+ numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
172+
173+ return [
174+ f"{ numba_cache_dir } /{ f .name } "
175+ for f in pathlib .Path (numba_cache_dir ).glob ("*nb*" )
176+ if f .is_file ()
177+ ]
157178
158179
159180def _recompile ():
@@ -202,16 +223,24 @@ def _save():
202223 -------
203224 None
204225 """
226+ global CACHE_CLEARED
227+
228+ if not CACHE_CLEARED : # pragma: no cover
229+ msg = "Numba njit cached files are not cleared before saving/overwriting. "
230+ msg = "You may need to call `cache.clear()` before calling `cache.save()`."
231+ warnings .warn (msg )
232+
205233 _enable ()
206234 _recompile ()
207235
236+ CACHE_CLEARED = False
237+
208238 return
209239
210240
211241def save ():
212242 """
213- Save/overwrite all the cache data files of
214- all-so-far compiled njit functions.
243+ Save/overwrite all of the cached njit functions.
215244
216245 Parameters
217246 ----------
@@ -220,13 +249,40 @@ def save():
220249 Returns
221250 -------
222251 None
252+
253+ Notes
254+ -----
255+ The cache is never cleared before saving/overwriting and may be explicitly cleared
256+ by calling `cache.clear()` before saving. It is best practice to call `cache.save()`
257+ only after calling all of your `njit` functions. If `cache.save()` is called for the
258+ first time (before any `njit` function is called) then only the `.nbi` files (i.e.,
259+ the "cache index") for all `njit` functions are saved. As each `njit` function (and
260+ sub-functions) is called then their corresponding `.nbc` file (i.e., "object code")
261+ is saved. Each `.nbc` file will only be saved after its `njit` function is called
262+ at least once. However, subsequent calls to `cache.save()` (after clearing the cache
263+ via `cache.clear()`) will automatically save BOTH the `.nbi` files as well as the
264+ `.nbc` files as long as their `njit` function has been called at least once.
265+
266+ Examples
267+ --------
268+ >>> import stumpy
269+ >>> from stumpy import cache
270+ >>> import numpy as np
271+ >>> cache.clear()
272+ >>> mp = stumpy.stump(np.array([584., -11., 23., 79., 1001., 0., -19.]), m=3)
273+ >>> cache.save()
223274 """
224275 if numba .config .DISABLE_JIT :
225276 msg = "Could not save/cache function because NUMBA JIT is disabled"
226277 warnings .warn (msg )
227278 else : # pragma: no cover
228279 warnings .warn (CACHE_WARNING )
229280
281+ if numba .config .CACHE_DIR != "" : # pragma: no cover
282+ msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. "
283+ msg += "The `stumpy` cache files may not be saved/cleared correctly!"
284+ warnings .warn (msg )
285+
230286 _save ()
231287
232288 return
0 commit comments