3131import urllib
3232
3333from etils import epath
34+ from tensorflow_datasets .core import lazy_imports_lib
3435from tensorflow_datasets .core import units
3536from tensorflow_datasets .core import utils
3637from tensorflow_datasets .core .download import checksums as checksums_lib
@@ -130,6 +131,43 @@ def _get_filename(response: Response) -> str:
130131 return _basename_from_url (response .url )
131132
132133
def _process_gdrive_confirmation(original_url: str, contents: str) -> str:
  """Extracts the download URL from a Google Drive confirmation page.

  The page is parsed with BeautifulSoup. The first `<form>` supplies the
  target (`action`), and the values of its `uuid`, `export`, `id` and
  `confirm` `<input>` elements (those that are present) are sent as query
  parameters.

  Args:
    original_url: The URL the confirmation page was originally retrieved from.
      Used in error messages and as the base for resolving a relative form
      action.
    contents: The confirmation page's HTML.

  Returns:
    The URL for downloading the file, resolved against `original_url`.

  Raises:
    ValueError: If the page contains no `<form>`, or the form has no
      non-empty `action` attribute.
  """
  bs4 = lazy_imports_lib.lazy_imports.bs4
  soup = bs4.BeautifulSoup(contents, 'html.parser')
  form = soup.find('form')
  # A missing form and a missing/empty action are reported identically.
  action = form.get('action', '') if form is not None else ''
  if not action:
    raise ValueError(
        f'Failed to obtain confirmation link for GDrive URL {original_url}.'
    )

  # Collect the hidden <input>s that carry the download parameters; inputs
  # that are absent from the form are simply left out of the query.
  params = {}
  for name in ('uuid', 'export', 'id', 'confirm'):
    input_tag = form.find('input', {'name': name})
    if input_tag is not None:
      params[name] = input_tag.get('value', '')
  query_string = urllib.parse.urlencode(params)

  # Merge with any query string already present in the action instead of
  # blindly appending '?...', which would produce a second '?' (and would
  # place the query after a fragment, if the action has one).
  action_parts = urllib.parse.urlsplit(action)
  merged_query = '&'.join(
      part for part in (action_parts.query, query_string) if part
  )
  download_url = urllib.parse.urlunsplit(
      action_parts._replace(query=merged_query)
  )
  # The action may be relative; resolve it against the page's own URL.
  return urllib.parse.urljoin(original_url, download_url)
169+
170+
133171class _Downloader :
134172 """Class providing async download API with checksum validation.
135173
@@ -318,11 +356,26 @@ def _open_with_requests(
318356 session .mount (
319357 'https://' , requests .adapters .HTTPAdapter (max_retries = retries )
320358 )
321- if _DRIVE_URL .match (url ):
322- url = _normalize_drive_url (url )
323359 with session .get (url , stream = True , ** kwargs ) as response :
324- _assert_status (response )
325- yield (response , response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ))
360+ if (
361+ _DRIVE_URL .match (url )
362+ and 'Content-Disposition' not in response .headers
363+ ):
364+ download_url = _process_gdrive_confirmation (url , response .text )
365+ with session .get (
366+ download_url , stream = True , ** kwargs
367+ ) as download_response :
368+ _assert_status (download_response )
369+ yield (
370+ download_response ,
371+ download_response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ),
372+ )
373+ else :
374+ _assert_status (response )
375+ yield (
376+ response ,
377+ response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ),
378+ )
326379
327380
328381@contextlib .contextmanager
@@ -338,13 +391,6 @@ def _open_with_urllib(
338391 )
339392
340393
341- def _normalize_drive_url (url : str ) -> str :
342- """Returns Google Drive url with confirmation token."""
343- # This bypasses the "Google Drive can't scan this file for viruses" warning
344- # when dowloading large files.
345- return url + '&confirm=t'
346-
347-
348394def _assert_status (response : requests .Response ) -> None :
349395 """Ensure the URL response is 200."""
350396 if response .status_code != 200 :
0 commit comments