3333from etils import epath
3434from tensorflow_datasets .core import units
3535from tensorflow_datasets .core import utils
36+ from tensorflow_datasets .core import lazy_imports_lib
3637from tensorflow_datasets .core .download import checksums as checksums_lib
3738from tensorflow_datasets .core .download import resource as resource_lib
3839from tensorflow_datasets .core .download import util as download_utils_lib
@@ -130,6 +131,44 @@ def _get_filename(response: Response) -> str:
130131 return _basename_from_url (response .url )
131132
132133
134+ def _process_gdrive_confirmation (original_url : str , contents : str ) -> str :
135+ """Process Google Drive confirmation page.
136+
137+ Extracts the download link from a Google Drive confirmation page.
138+
139+ Args:
140+ original_url: The URL the confirmation page was originally
141+ retrieved from.
142+ contents: The confirmation page's HTML.
143+
144+ Returns:
145+ download_url: The URL for downloading the file.
146+ """
147+ bs4 = lazy_imports_lib .lazy_imports .bs4
148+ soup = bs4 .BeautifulSoup (contents , 'html.parser' )
149+ form = soup .find ('form' )
150+ if not form :
151+ raise ValueError (
152+ f'Failed to obtain confirmation link for GDrive URL { original_url } .'
153+ )
154+ action = form .get ('action' , '' )
155+ if not action :
156+ raise ValueError (
157+ f'Failed to obtain confirmation link for GDrive URL { original_url } .'
158+ )
159+ # Find the <input>s named 'uuid', 'export', 'id' and 'confirm'
160+ input_names = ['uuid' , 'export' , 'id' , 'confirm' ]
161+ params = {}
162+ for name in input_names :
163+ input_tag = form .find ('input' , {'name' : name })
164+ if input_tag :
165+ params [name ] = input_tag .get ('value' , '' )
166+ query_string = urllib .parse .urlencode (params )
167+ download_url = f'{ action } ?{ query_string } ' if query_string else action
168+ download_url = urllib .parse .urljoin (original_url , download_url )
169+ return download_url
170+
171+
133172class _Downloader :
134173 """Class providing async download API with checksum validation.
135174
@@ -318,11 +357,15 @@ def _open_with_requests(
318357 session .mount (
319358 'https://' , requests .adapters .HTTPAdapter (max_retries = retries )
320359 )
321- if _DRIVE_URL .match (url ):
322- url = _normalize_drive_url (url )
323360 with session .get (url , stream = True , ** kwargs ) as response :
324- _assert_status (response )
325- yield (response , response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ))
361+ if _DRIVE_URL .match (url ) and 'Content-Disposition' not in response .headers :
362+ download_url = _process_gdrive_confirmation (url , response .text )
363+ with session .get (download_url , stream = True , ** kwargs ) as download_response :
364+ _assert_status (download_response )
365+ yield (download_response , download_response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ))
366+ else :
367+ _assert_status (response )
368+ yield (response , response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ))
326369
327370
328371@contextlib .contextmanager
@@ -338,13 +381,6 @@ def _open_with_urllib(
338381 )
339382
340383
341- def _normalize_drive_url (url : str ) -> str :
342- """Returns Google Drive url with confirmation token."""
343- # This bypasses the "Google Drive can't scan this file for viruses" warning
344- # when dowloading large files.
345- return url + '&confirm=t'
346-
347-
348384def _assert_status (response : requests .Response ) -> None :
349385 """Ensure the URL response is 200."""
350386 if response .status_code != 200 :
0 commit comments