@@ -114,16 +114,27 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
114114 raise Exception (f"Failed to read dataset at { url } " ) from None
115115
116116
117- def _verify_files_dont_exist (paths : Iterable [Union [str , Path ]]) -> None :
117+ def _verify_files_dont_exist (
118+ paths : Iterable [Union [str , Path ]], remove_if_exist : bool = False
119+ ) -> None :
118120 """
119121 Verifies all paths in 'paths' don't exist.
120122 :param paths: A iterable of strs or pathlib.Paths.
123+ :param remove_if_exist=False: Removes file at path if they already exist.
121124 :returns: None
122125 :raises FileExistsError: On the first path found that already exists.
123126 """
124127 for path in paths :
125- if Path (path ).exists ():
126- raise FileExistsError (f"Error: File '{ path } ' already exists." )
128+ path = Path (path )
129+ if path .exists ():
130+ if remove_if_exist :
131+ if path .is_symlink ():
132+ realpath = path .resolve ()
133+ path .unlink (realpath )
134+ else :
135+ shutil .rmtree (path )
136+ else :
137+ raise FileExistsError (f"Error: File '{ path } ' already exists." )
127138
128139
129140def _is_file_to_symlink (path : Path ) -> bool :
@@ -188,7 +199,9 @@ async def read(url: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> bytes:
188199 return b"" .join ([chunk async for chunk in _get_chunks (url , chunk_size )])
189200
190201
191- async def prepare (url : str , path : Optional [str ] = None , verbose : bool = True ) -> None :
202+ async def prepare (
203+ url : str , path : Optional [str ] = None , verbose : bool = True , overwrite : bool = False
204+ ) -> None :
192205 """
193206 Prepares a dataset for learners. Downloads a dataset from the given url,
194207 decompresses it if necessary. If not using jupyterlite, will extract to
@@ -200,6 +213,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
200213
201214 :param url: The URL to download the dataset from.
202215 :param path: The path the dataset will be available at. Current working directory by default.
216+ :param verbose=True: Prints saved path if True.
217+ :param overwrite=False: Overwrites any existing files at destination if they exist.
203218 :raise InvalidURLException: When URL is invalid.
204219 :raise FileExistsError: it raises this when a file to be symlinked already exists.
205220 :raise ValueError: When requested path is in /tmp, or cannot be saved to path.
@@ -239,7 +254,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
239254 path / child .name
240255 for child in map (Path , tf .getnames ())
241256 if len (child .parents ) == 1 and _is_file_to_symlink (child )
242- ]
257+ ],
258+ overwrite ,
243259 ) # Only check if top-level fileobject
244260 pbar = tqdm (iterable = tf .getmembers (), total = len (tf .getmembers ()))
245261 pbar .set_description (f"Extracting { filename } " )
@@ -253,15 +269,16 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
253269 path / child .name
254270 for child in map (Path , zf .namelist ())
255271 if len (child .parents ) == 1 and _is_file_to_symlink (child )
256- ]
272+ ],
273+ overwrite ,
257274 )
258275 pbar = tqdm (iterable = zf .infolist (), total = len (zf .infolist ()))
259276 pbar .set_description (f"Extracting { filename } " )
260277 for member in pbar :
261278 zf .extract (member = member , path = extract_dir )
262279 tmp_download_file .unlink ()
263280 else :
264- _verify_files_dont_exist ([path / filename ])
281+ _verify_files_dont_exist ([path / filename ], overwrite )
265282 shutil .move (tmp_download_file , extract_dir / filename )
266283
267284 # If in jupyterlite environment, the extract_dir = path, so the files are already there.
@@ -274,8 +291,36 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
274291 print (f"Saved to '{ relpath (path .resolve ())} '" )
275292
276293
277- if _is_jupyterlite ():
278- tqdm .monitor_interval = 0
294+ def setup () -> None :
295+ if _is_jupyterlite ():
296+ tqdm .monitor_interval = 0
297+
298+ try :
299+ import sys # pyright: ignore
300+
301+ ipython = get_ipython ()
302+
303+ def hide_traceback (
304+ exc_tuple = None ,
305+ filename = None ,
306+ tb_offset = None ,
307+ exception_only = False ,
308+ running_compiled_code = False ,
309+ ):
310+ etype , value , tb = sys .exc_info ()
311+ value .__cause__ = None # suppress chained exceptions
312+ return ipython ._showtraceback (
313+ etype , value , ipython .InteractiveTB .get_exception_only (etype , value )
314+ )
315+
316+ ipython .showtraceback = hide_traceback
317+
318+ except NameError :
319+ pass
320+
321+
322+ setup ()
323+
279324
280325# For backwards compatibility
281326download_dataset = download
0 commit comments