1111
1212import requests
1313
14+ # this will let us do any other caching we might need in the future in the same
15+ # cache dir (adjacent to "downloads")
16+ _CACHEDIR_NAME = os .path .join ("check_jsonschema" , "downloads" )
17+
18+ _LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z"
19+
20+
21+ def _get_default_cache_dir () -> str | None :
22+ sysname = platform .system ()
23+
24+ # on windows, try to get the appdata env var
25+ # this *could* result in cache_dir=None, which is fine, just skip caching in
26+ # that case
27+ if sysname == "Windows" :
28+ cache_dir = os .getenv ("LOCALAPPDATA" , os .getenv ("APPDATA" ))
29+ # macOS -> app support dir
30+ elif sysname == "Darwin" :
31+ cache_dir = os .path .expanduser ("~/Library/Caches" )
32+ # default for unknown platforms, namely linux behavior
33+ # use XDG env var and default to ~/.cache/
34+ else :
35+ cache_dir = os .getenv ("XDG_CACHE_HOME" , os .path .expanduser ("~/.cache" ))
36+
37+ if cache_dir :
38+ cache_dir = os .path .join (cache_dir , _CACHEDIR_NAME )
39+
40+ return cache_dir
41+
42+
def _lastmod_from_response(response: requests.Response) -> float:
    """Return the response's "last-modified" header as a POSIX timestamp.

    The header value is parsed with ``_LASTMOD_FMT`` and interpreted as
    UTC/GMT, so the result is directly comparable with file mtimes such as
    ``os.path.getmtime()`` regardless of the local timezone.

    Returns:
        Seconds since the epoch, or ``0.0`` when the header is missing,
        unparseable, or out of range.
    """
    # local import: only this function needs it
    import calendar

    try:
        parsed = time.strptime(response.headers["last-modified"], _LASTMOD_FMT)
        # timegm treats the struct_time as UTC; time.mktime would treat it as
        # local time and skew the result by the local UTC offset
        return float(calendar.timegm(parsed))
    # OverflowError: time outside of platform-specific bounds
    # ValueError: malformed/unparseable
    # LookupError: no such header
    except (OverflowError, ValueError, LookupError):
        return 0.0
53+
54+
def _get_request(
    file_url: str, *, response_ok: t.Callable[[requests.Response], bool]
) -> requests.Response:
    """GET ``file_url`` as a streamed response, retrying on bad responses.

    Up to three requests are made. A response is accepted when it has an OK
    status and ``response_ok(response)`` returns true; the first accepted
    response is returned to the caller.

    Args:
        file_url: the URL to fetch.
        response_ok: extra acceptance check applied to responses with an OK
            status.

    Raises:
        FailedDownloadError: if no attempt produced an accepted response, or
            if ``requests`` raised a ``RequestException`` (which aborts
            immediately, without further retries).
    """
    last_status = None
    try:
        for _attempt in range(3):
            r = requests.get(file_url, stream=True)
            if r.ok and response_ok(r):
                return r
            last_status = r.status_code
            # a streamed response holds on to its connection until it is
            # consumed or closed; release it before retrying or failing
            r.close()
    except requests.RequestException as e:
        raise FailedDownloadError("encountered error during download") from e
    raise FailedDownloadError(
        f"got response with status={last_status}, retries exhausted"
    )
70+
71+
72+ def _atomic_write (dest : str , content : bytes ) -> None :
73+ # download to a temp file and then move to the dest
74+ # this makes the download safe if run in parallel (parallel runs
75+ # won't create a new empty file for writing and cause failures)
76+ fp = tempfile .NamedTemporaryFile (mode = "wb" , delete = False )
77+ fp .write (content )
78+ fp .close ()
79+ shutil .copy (fp .name , dest )
80+ os .remove (fp .name )
81+
82+
def _cache_hit(cachefile: str, response: requests.Response) -> bool:
    """Report whether ``cachefile`` is an up-to-date copy for ``response``.

    The mtime of the cached file is compared against the remote last-modified
    time taken from the response. It is considered a hit if the local file is
    at least as new as the remote file.

    Returns:
        ``False`` if the cached file does not exist (or its mtime cannot be
        read), otherwise the result of the mtime comparison.
    """
    # ask for the mtime directly rather than checking existence first, so a
    # file which vanishes between the two steps is a miss instead of an error
    try:
        local_mtime = os.path.getmtime(cachefile)
    except OSError:
        return False

    return local_mtime >= _lastmod_from_response(response)
93+
1494
class FailedDownloadError(Exception):
    """Raised when a file could not be downloaded."""
1797
1898
1999class CacheDownloader :
20- _LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z"
21-
22- # changed in v0.5.0
23- # original cache dir was "jsonschema_validate"
24- # this will let us do any other caching we might need in the future in the same
25- # cache dir (adjacent to "downloads")
26- _CACHEDIR_NAME = os .path .join ("check_jsonschema" , "downloads" )
27-
28100 def __init__ (
29101 self ,
30- file_url : str ,
31- filename : str | None = None ,
32102 cache_dir : str | None = None ,
33103 disable_cache : bool = False ,
34104 validation_callback : t .Callable [[bytes ], t .Any ] | None = None ,
35105 ):
36- self ._file_url = file_url
37- self ._filename = filename or file_url .split ("/" )[- 1 ]
38- self ._cache_dir = cache_dir or self ._compute_default_cache_dir ()
106+ self ._cache_dir = cache_dir or _get_default_cache_dir ()
39107 self ._disable_cache = disable_cache
40108 self ._validation_callback = validation_callback
41109
42- def _compute_default_cache_dir (self ) -> str | None :
43- sysname = platform .system ()
44-
45- # on windows, try to get the appdata env var
46- # this *could* result in cache_dir=None, which is fine, just skip caching in
47- # that case
48- if sysname == "Windows" :
49- cache_dir = os .getenv ("LOCALAPPDATA" , os .getenv ("APPDATA" ))
50- # macOS -> app support dir
51- elif sysname == "Darwin" :
52- cache_dir = os .path .expanduser ("~/Library/Caches" )
53- # default for unknown platforms, namely linux behavior
54- # use XDG env var and default to ~/.cache/
55- else :
56- cache_dir = os .getenv ("XDG_CACHE_HOME" , os .path .expanduser ("~/.cache" ))
57-
58- if cache_dir :
59- cache_dir = os .path .join (cache_dir , self ._CACHEDIR_NAME )
60-
61- return cache_dir
62-
63- def _get_request (
64- self , * , response_ok : t .Callable [[requests .Response ], bool ]
65- ) -> requests .Response :
66- try :
67- r : requests .Response | None = None
68- for _attempt in range (3 ):
69- r = requests .get (self ._file_url , stream = True )
70- if r .ok and response_ok (r ):
71- return r
72- assert r is not None
73- raise FailedDownloadError (
74- f"got response with status={ r .status_code } , retries exhausted"
75- )
76- except requests .RequestException as e :
77- raise FailedDownloadError ("encountered error during download" ) from e
78-
79- def _lastmod_from_response (self , response : requests .Response ) -> float :
80- try :
81- return time .mktime (
82- time .strptime (response .headers ["last-modified" ], self ._LASTMOD_FMT )
83- )
84- # OverflowError: time outside of platform-specific bounds
85- # ValueError: malformed/unparseable
86- # LookupError: no such header
87- except (OverflowError , ValueError , LookupError ):
88- return 0.0
89-
90- def _cache_hit (self , cachefile : str , response : requests .Response ) -> bool :
91- # no file? miss
92- if not os .path .exists (cachefile ):
93- return False
94-
95- # compare mtime on any cached file against the remote last-modified time
96- # it is considered a hit if the local file is at least as new as the remote file
97- local_mtime = os .path .getmtime (cachefile )
98- remote_mtime = self ._lastmod_from_response (response )
99- return local_mtime >= remote_mtime
100-
101- def _write (self , dest : str , response : requests .Response ) -> None :
102- # download to a temp file and then move to the dest
103- # this makes the download safe if run in parallel (parallel runs
104- # won't create a new empty file for writing and cause failures)
105- fp = tempfile .NamedTemporaryFile (mode = "wb" , delete = False )
106- fp .write (response .content )
107- fp .close ()
108- shutil .copy (fp .name , dest )
109- os .remove (fp .name )
110-
111110 def _validate (self , response : requests .Response ) -> bool :
112111 if not self ._validation_callback :
113112 return True
@@ -118,32 +117,52 @@ def _validate(self, response: requests.Response) -> bool:
118117 except ValueError :
119118 return False
120119
121- def _download (self ) -> str :
122- assert self ._cache_dir
120+ def _download (self , file_url : str , filename : str ) -> str :
121+ assert self ._cache_dir is not None
123122 os .makedirs (self ._cache_dir , exist_ok = True )
124- dest = os .path .join (self ._cache_dir , self . _filename )
123+ dest = os .path .join (self ._cache_dir , filename )
125124
126125 def check_response_for_download (r : requests .Response ) -> bool :
127126 # if the response indicates a cache hit, treat it as valid
128127 # this ensures that we short-circuit any further evaluation immediately on
129128 # a hit
130- if self . _cache_hit (dest , r ):
129+ if _cache_hit (dest , r ):
131130 return True
132131 # we now know it's not a hit, so validate the content (forces download)
133132 return self ._validate (r )
134133
135- response = self . _get_request (response_ok = check_response_for_download )
134+ response = _get_request (file_url , response_ok = check_response_for_download )
136135 # check to see if we have a file which matches the connection
137136 # only download if we do not (cache miss, vs hit)
138- if not self . _cache_hit (dest , response ):
139- self . _write (dest , response )
137+ if not _cache_hit (dest , response ):
138+ _atomic_write (dest , response . content )
140139
141140 return dest
142141
143142 @contextlib .contextmanager
144- def open (self ) -> t .Iterator [t .IO [bytes ]]:
143+ def open (self , file_url : str , filename : str ) -> t .Iterator [t .IO [bytes ]]:
145144 if (not self ._cache_dir ) or self ._disable_cache :
146- yield io .BytesIO (self . _get_request (response_ok = self ._validate ).content )
145+ yield io .BytesIO (_get_request (file_url , response_ok = self ._validate ).content )
147146 else :
148- with open (self ._download (), "rb" ) as fp :
147+ with open (self ._download (file_url , filename ), "rb" ) as fp :
149148 yield fp
149+
150+ def bind (self , file_url : str , filename : str | None = None ) -> BoundCacheDownloader :
151+ return BoundCacheDownloader (file_url , filename , self )
152+
153+
class BoundCacheDownloader:
    """A ``CacheDownloader`` paired with a fixed URL and cache filename."""

    def __init__(
        self,
        file_url: str,
        filename: str | None,
        downloader: CacheDownloader,
    ):
        """Remember the URL, the filename, and the downloader to delegate to.

        When ``filename`` is not given (or empty), the last "/"-separated
        segment of ``file_url`` is used instead.
        """
        if not filename:
            filename = file_url.rsplit("/", 1)[-1]
        self._file_url = file_url
        self._filename = filename
        self._downloader = downloader

    @contextlib.contextmanager
    def open(self) -> t.Iterator[t.IO[bytes]]:
        """Open the bound URL via the downloader and yield its binary stream."""
        stream_cm = self._downloader.open(self._file_url, self._filename)
        with stream_cm as stream:
            yield stream
0 commit comments