11import os
2+ import struct
23import logging
4+ from tqdm import tqdm
35
46
57logger = logging .getLogger (__name__ )
@@ -30,12 +32,35 @@ def download(url, path, save_file=None, md5=None):
3032 return save_file
3133
3234
def smart_open(file_name, mode="rb"):
    """
    Open a regular file or a compressed (``.gz`` / ``.bz2``) file.

    This function can be used as drop-in replacement of the builtin function `open()`.
    The compression format is selected from the file extension only; the file
    content is not inspected.

    Parameters:
        file_name (str): file name
        mode (str, optional): open mode for the file stream.
            Both binary modes (e.g. ``"rb"``, ``"wb"``) and text modes
            (e.g. ``"rt"``, ``"wt"``) are accepted. Note that for compressed
            files a mode without ``"t"`` (e.g. ``"r"``) is treated as binary.

    Returns:
        A file object: the result of ``bz2.open`` for ``.bz2`` files,
        ``gzip.open`` for ``.gz`` files, and the builtin ``open`` otherwise.
    """
    import bz2
    import gzip

    extension = os.path.splitext(file_name)[1]
    # Use the module-level ``open`` helpers rather than the BZ2File / GzipFile
    # constructors: they return the same objects for binary modes, but also
    # support text modes, which the constructors reject with ValueError.
    if extension == ".bz2":
        return bz2.open(file_name, mode)
    if extension == ".gz":
        return gzip.open(file_name, mode)
    return open(file_name, mode)
55+
56+
3357def extract (zip_file , member = None ):
3458 """
3559 Extract files from a zip file. Currently, ``zip``, ``gz``, ``tar.gz``, ``tar`` file types are supported.
3660
3761 Parameters:
38- member (str, optional): extract a specific member from the zip file.
62+ zip_file (str): file name
63+ member (str, optional): extract specific member from the zip file.
3964 If not specified, extract all members.
4065 """
4166 import gzip
@@ -47,40 +72,64 @@ def extract(zip_file, member=None):
4772 if zip_name .endswith (".tar" ):
4873 extension = ".tar" + extension
4974 zip_name = zip_name [:- 4 ]
50-
51- if member is None :
52- save_file = zip_name
53- else :
54- save_file = os .path .join (os .path .dirname (zip_name ), os .path .basename (member ))
55- if os .path .exists (save_file ):
56- return save_file
57-
58- if member is None :
59- logger .info ("Extracting %s to %s" % (zip_file , save_file ))
60- else :
61- logger .info ("Extracting %s from %s to %s" % (member , zip_file , save_file ))
75+ save_path = os .path .dirname (zip_file )
6276
6377 if extension == ".gz" :
64- with gzip .open (zip_file , "rb" ) as fin , open (save_file , "wb" ) as fout :
65- shutil .copyfileobj (fin , fout )
78+ member = os .path .basename (zip_name )
79+ members = [member ]
80+ save_files = [os .path .join (save_path , member )]
81+ for _member , save_file in zip (members , save_files ):
82+ with open (zip_file , "rb" ) as fin :
83+ fin .seek (- 4 , 2 )
84+ file_size = struct .unpack ("<I" , fin .read ())[0 ]
85+ with gzip .open (zip_file , "rb" ) as fin :
86+ if not os .path .exists (save_file ) or file_size != os .path .getsize (save_file ):
87+ logger .info ("Extracting %s to %s" % (zip_file , save_file ))
88+ with open (save_file , "wb" ) as fout :
89+ shutil .copyfileobj (fin , fout )
6690 elif extension in [".tar.gz" , ".tgz" , ".tar" ]:
67- if member is None :
68- with tarfile .open (zip_file , "r" ) as fin :
69- fin .extractall (save_file )
91+ tar = tarfile .open (zip_file , "r" )
92+ if member is not None :
93+ members = [member ]
94+ save_files = [os .path .join (save_path , os .path .basename (member ))]
95+ logger .info ("Extracting %s from %s to %s" % (member , zip_file , save_files [0 ]))
7096 else :
71- with tarfile .open (zip_file , "r" ).extractfile (member ) as fin , open (save_file , "wb" ) as fout :
72- shutil .copyfileobj (fin , fout )
97+ members = tar .getnames ()
98+ save_files = [os .path .join (save_path , _member ) for _member in members ]
99+ logger .info ("Extracting %s to %s" % (zip_file , save_path ))
100+ for _member , save_file in zip (members , save_files ):
101+ if tar .getmember (_member ).isdir ():
102+ os .makedirs (save_file , exist_ok = True )
103+ continue
104+ os .makedirs (os .path .dirname (save_file ), exist_ok = True )
105+ if not os .path .exists (save_file ) or tar .getmember (_member ).size != os .path .getsize (save_file ):
106+ with tar .extractfile (_member ) as fin , open (save_file , "wb" ) as fout :
107+ shutil .copyfileobj (fin , fout )
73108 elif extension == ".zip" :
74- if member is None :
75- with zipfile .ZipFile (zip_file ) as fin :
76- fin .extractall (save_file )
109+ zipped = zipfile .ZipFile (zip_file )
110+ if member is not None :
111+ members = [member ]
112+ save_files = [os .path .join (save_path , os .path .basename (member ))]
113+ logger .info ("Extracting %s from %s to %s" % (member , zip_file , save_files [0 ]))
77114 else :
78- with zipfile .ZipFile (zip_file ).open (member , "r" ) as fin , open (save_file , "wb" ) as fout :
79- shutil .copyfileobj (fin , fout )
115+ members = zipped .namelist ()
116+ save_files = [os .path .join (save_path , _member ) for _member in members ]
117+ logger .info ("Extracting %s to %s" % (zip_file , save_path ))
118+ for _member , save_file in zip (members , save_files ):
119+ if zipped .getinfo (_member ).is_dir ():
120+ os .makedirs (save_file , exist_ok = True )
121+ continue
122+ os .makedirs (os .path .dirname (save_file ), exist_ok = True )
123+ if not os .path .exists (save_file ) or zipped .getinfo (_member ).file_size != os .path .getsize (save_file ):
124+ with zipped .open (_member , "r" ) as fin , open (save_file , "wb" ) as fout :
125+ shutil .copyfileobj (fin , fout )
80126 else :
81127 raise ValueError ("Unknown file extension `%s`" % extension )
82128
83- return save_file
129+ if len (save_files ) == 1 :
130+ return save_files [0 ]
131+ else :
132+ return save_path
84133
85134
86135def compute_md5 (file_name , chunk_size = 65536 ):
0 commit comments