1111
1212import requests
1313
14- # this will let us do any other caching we might need in the future in the same
15- # cache dir (adjacent to "downloads")
16- _CACHEDIR_NAME = os .path .join ("check_jsonschema" , "downloads" )
17-
1814_LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z"
1915
2016
21- def _get_default_cache_dir () -> str | None :
17+ def _base_cache_dir () -> str | None :
2218 sysname = platform .system ()
2319
2420 # on windows, try to get the appdata env var
@@ -34,9 +30,13 @@ def _get_default_cache_dir() -> str | None:
3430 else :
3531 cache_dir = os .getenv ("XDG_CACHE_HOME" , os .path .expanduser ("~/.cache" ))
3632
37- if cache_dir :
38- cache_dir = os .path .join (cache_dir , _CACHEDIR_NAME )
33+ return cache_dir
3934
35+
def _resolve_cache_dir(dirname: str = "downloads") -> str | None:
    """Return the check_jsonschema cache path for *dirname*.

    The path is ``<base>/check_jsonschema/<dirname>``, where ``<base>`` is
    whatever ``_base_cache_dir()`` reports. If no base directory is
    available (a falsy result), that falsy value is returned unchanged so
    callers can detect that caching is not possible.
    """
    base_dir = _base_cache_dir()
    if not base_dir:
        return base_dir
    return os.path.join(base_dir, "check_jsonschema", dirname)
4141
4242
@@ -55,18 +55,21 @@ def _lastmod_from_response(response: requests.Response) -> float:
def _get_request(
    file_url: str, *, response_ok: t.Callable[[requests.Response], bool]
) -> requests.Response:
    """Fetch *file_url* with a streaming GET, retrying on failure.

    Up to three attempts are made (one initial try plus two retries). An
    attempt succeeds when the response has an OK status and
    ``response_ok(response)`` returns true; that response is returned.

    :param file_url: the URL to download
    :param response_ok: predicate applied to each OK response; returning
        false causes the attempt to be retried
    :raises FailedDownloadError: if the final attempt raises a
        ``requests.RequestException``, or if every attempt completes without
        producing an acceptable response
    """
    num_retries = 2
    r: requests.Response | None = None
    for _attempt in range(num_retries + 1):
        try:
            r = requests.get(file_url, stream=True)
            # The request is streamed, so a response_ok callback which reads
            # the body performs the actual download. Keep it inside the try
            # so that network errors raised at that point are retried and
            # wrapped exactly like errors from the initial request.
            if r.ok and response_ok(r):
                return r
        except requests.RequestException as e:
            if _attempt == num_retries:
                raise FailedDownloadError("encountered error during download") from e
    # Reaching this point means the last attempt did not raise, so at least
    # one response was received; the assert narrows the type for checkers.
    assert r is not None
    raise FailedDownloadError(
        f"got response with status={r.status_code}, retries exhausted"
    )
7073
7174
7275def _atomic_write (dest : str , content : bytes ) -> None :
@@ -97,27 +100,19 @@ class FailedDownloadError(Exception):
97100
98101
99102class CacheDownloader :
def __init__(self, cache_dir: str | None = None, disable_cache: bool = False):
    """Set up the downloader's cache location and caching flag.

    :param cache_dir: directory name passed to ``_resolve_cache_dir``; when
        ``None`` the resolver's own default name is used
    :param disable_cache: stored as-is; consulted when opening a URL
    """
    self._disable_cache = disable_cache
    self._cache_dir = (
        _resolve_cache_dir() if cache_dir is None else _resolve_cache_dir(cache_dir)
    )
108- self ._validation_callback = validation_callback
109-
110- def _validate (self , response : requests .Response ) -> bool :
111- if not self ._validation_callback :
112- return True
113-
114- try :
115- self ._validation_callback (response .content )
116- return True
117- except ValueError :
118- return False
119109
120- def _download (self , file_url : str , filename : str ) -> str :
110+ def _download (
111+ self ,
112+ file_url : str ,
113+ filename : str ,
114+ response_ok : t .Callable [[requests .Response ], bool ],
115+ ) -> str :
121116 assert self ._cache_dir is not None
122117 os .makedirs (self ._cache_dir , exist_ok = True )
123118 dest = os .path .join (self ._cache_dir , filename )
@@ -129,7 +124,7 @@ def check_response_for_download(r: requests.Response) -> bool:
129124 if _cache_hit (dest , r ):
130125 return True
131126 # we now know it's not a hit, so validate the content (forces download)
132- return self . _validate (r )
127+ return response_ok (r )
133128
134129 response = _get_request (file_url , response_ok = check_response_for_download )
135130 # check to see if we have a file which matches the connection
@@ -140,15 +135,31 @@ def check_response_for_download(r: requests.Response) -> bool:
140135 return dest
141136
@contextlib.contextmanager
def open(
    self,
    file_url: str,
    filename: str,
    validate_response: t.Callable[[requests.Response], bool],
) -> t.Iterator[t.IO[bytes]]:
    """Yield a binary file object holding the content of *file_url*.

    When a cache directory is set and caching is not disabled, the file is
    obtained via ``_download`` and opened from disk. Otherwise the response
    body is fetched directly and served from an in-memory buffer.

    :param file_url: the URL to fetch
    :param filename: name passed to ``_download`` for the cached file
    :param validate_response: predicate passed along as ``response_ok``
    """
    use_cache = bool(self._cache_dir) and not self._disable_cache
    if use_cache:
        cached_path = self._download(
            file_url, filename, response_ok=validate_response
        )
        with open(cached_path, "rb") as fp:
            yield fp
    else:
        response = _get_request(file_url, response_ok=validate_response)
        yield io.BytesIO(response.content)
149153
def bind(
    self,
    file_url: str,
    filename: str | None = None,
    validation_callback: t.Callable[[bytes], t.Any] | None = None,
) -> BoundCacheDownloader:
    """Create a ``BoundCacheDownloader`` tied to this downloader.

    :param file_url: the URL the bound downloader will fetch
    :param filename: optional filename, forwarded unchanged
    :param validation_callback: optional callable taking the downloaded
        bytes, forwarded unchanged
    """
    bound = BoundCacheDownloader(
        file_url, filename, self, validation_callback=validation_callback
    )
    return bound
152163
153164
154165class BoundCacheDownloader :
@@ -157,12 +168,29 @@ def __init__(
157168 file_url : str ,
158169 filename : str | None ,
159170 downloader : CacheDownloader ,
171+ * ,
172+ validation_callback : t .Callable [[bytes ], t .Any ] | None = None ,
160173 ):
161174 self ._file_url = file_url
162175 self ._filename = filename or file_url .split ("/" )[- 1 ]
163176 self ._downloader = downloader
177+ self ._validation_callback = validation_callback
164178
@contextlib.contextmanager
def open(self) -> t.Iterator[t.IO[bytes]]:
    """Yield a binary file object for the bound URL.

    Delegates to the wrapped downloader's ``open``, supplying the bound URL
    and filename plus this object's response validator.
    """
    stream_cm = self._downloader.open(
        self._file_url,
        self._filename,
        validate_response=self._validate_response,
    )
    with stream_cm as fp:
        yield fp
187+
def _validate_response(self, response: requests.Response) -> bool:
    """Report whether *response* passes the configured validation callback.

    With no callback configured, every response is accepted and the body is
    not read. Otherwise the callback is called with ``response.content``;
    a ``ValueError`` from it means the response is rejected. Any other
    exception propagates to the caller.
    """
    callback = self._validation_callback
    if not callback:
        return True

    try:
        callback(response.content)
    except ValueError:
        return False
    return True
0 commit comments