@@ -60,23 +60,18 @@ def _compute_default_cache_dir(self) -> str | None:
6060
6161 return cache_dir
6262
def _get_request(
    self, *, response_ok: t.Callable[[requests.Response], bool]
) -> requests.Response:
    """Fetch ``self._file_url``, retrying until a response is accepted.

    Up to three GET requests are issued with ``stream=True``. A response is
    accepted, and returned immediately, when its status is OK *and*
    ``response_ok(response)`` returns True.

    :param response_ok: keyword-only predicate applied to each response whose
        status is OK; returning False causes another attempt.
    :returns: the first accepted response.
    :raises FailedDownloadError: if all three attempts are rejected, or if
        ``requests`` raises a ``RequestException`` while fetching.
    """
    try:
        r: requests.Response | None = None
        for _attempt in range(3):
            r = requests.get(self._file_url, stream=True)
            if r.ok and response_ok(r):
                return r
            # The request was made with stream=True, so the connection is only
            # released once the body is consumed or the response is closed.
            # Close rejected responses explicitly so retries do not leave
            # connections dangling. The status code stays readable afterwards.
            r.close()
        # the loop always runs at least once; this narrows the type for checkers
        assert r is not None
        raise FailedDownloadError(
            f"got response with status={r.status_code}, retries exhausted"
        )
    except requests.RequestException as e:
        raise FailedDownloadError("encountered error during download") from e
@@ -113,12 +108,31 @@ def _write(self, dest: str, response: requests.Response) -> None:
113108 shutil .copy (fp .name , dest )
114109 os .remove (fp .name )
115110
def _validate(self, response: requests.Response) -> bool:
    """Report whether *response* passes the configured validation callback.

    :param response: the response whose ``content`` is handed to
        ``self._validation_callback``.
    :returns: True when no callback is configured (the attribute is None) or
        when the callback returns without raising; False when the callback
        raises ``ValueError``. Any other exception propagates to the caller.
    """
    # Test against None explicitly: a callable object that happens to be
    # falsy (e.g. defines __bool__ or __len__) must still be invoked.
    if self._validation_callback is None:
        return True

    # Keep the try body to the single call that may signal invalid content.
    try:
        self._validation_callback(response.content)
    except ValueError:
        return False
    return True
120+
116121 def _download (self ) -> str :
117122 assert self ._cache_dir
118123 os .makedirs (self ._cache_dir , exist_ok = True )
119124 dest = os .path .join (self ._cache_dir , self ._filename )
120125
121- response = self ._get_request ()
126+ def check_response_for_download (r : requests .Response ) -> bool :
127+ # if the response indicates a cache hit, treat it as valid
128+ # this ensures that we short-circuit any further evaluation immediately on
129+ # a hit
130+ if self ._cache_hit (dest , r ):
131+ return True
132+ # we now know it's not a hit, so validate the content (forces download)
133+ return self ._validate (r )
134+
135+ response = self ._get_request (response_ok = check_response_for_download )
122136 # check to see if we have a file which matches the connection
123137 # only download if we do not (cache miss, vs hit)
124138 if not self ._cache_hit (dest , response ):
@@ -129,7 +143,7 @@ def _download(self) -> str:
129143 @contextlib .contextmanager
130144 def open (self ) -> t .Iterator [t .IO [bytes ]]:
131145 if (not self ._cache_dir ) or self ._disable_cache :
132- yield io .BytesIO (self ._get_request ().content )
146+ yield io .BytesIO (self ._get_request (response_ok = self . _validate ).content )
133147 else :
134148 with open (self ._download (), "rb" ) as fp :
135149 yield fp
0 commit comments