13 | 13 | import pickle |
14 | 14 | import sys |
15 | 15 | import tarfile |
| 16 | +import gzip |
16 | 17 | import zipfile |
17 | 18 | from pathlib import Path |
18 | 19 | from typing import Callable, Optional, Tuple, Union |
@@ -165,6 +166,36 @@ def iterate_images(): |
165 | 166 |
166 | 167 | #---------------------------------------------------------------------------- |
167 | 168 |
| 169 | +def open_mnist(images_gz: str, *, max_images: Optional[int]): |
| 170 | + labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz') |
| 171 | + assert labels_gz != images_gz |
| 172 | + images = [] |
| 173 | + labels = [] |
| 174 | + |
| 175 | + with gzip.open(images_gz, 'rb') as f: |
| 176 | + images = np.frombuffer(f.read(), np.uint8, offset=16) |
| 177 | + with gzip.open(labels_gz, 'rb') as f: |
| 178 | + labels = np.frombuffer(f.read(), np.uint8, offset=8) |
| 179 | + |
| 180 | + images = images.reshape(-1, 28, 28) |
| 181 | + images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) |
| 182 | + assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 |
| 183 | + assert labels.shape == (60000,) and labels.dtype == np.uint8 |
| 184 | + assert np.min(images) == 0 and np.max(images) == 255 |
| 185 | + assert np.min(labels) == 0 and np.max(labels) == 9 |
| 186 | + |
| 187 | + max_idx = maybe_min(len(images), max_images) |
| 188 | + |
| 189 | + def iterate_images(): |
| 190 | + for idx, img in enumerate(images): |
| 191 | + yield dict(img=img, label=int(labels[idx])) |
| 192 | + if idx >= max_idx-1: |
| 193 | + break |
| 194 | + |
| 195 | + return max_idx, iterate_images() |
| 196 | + |
| 197 | +#---------------------------------------------------------------------------- |
| 198 | + |
168 | 199 | def make_transform( |
169 | 200 | transform: Optional[str], |
170 | 201 | output_width: Optional[int], |
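
For reference, the offset=16 and offset=8 values used in open_mnist() above skip the IDX file headers: the image file begins with four big-endian int32 fields (magic 2051, image count, rows, columns) and the label file with two (magic 2049, label count). Below is a minimal standalone sketch, not part of this commit, that validates those headers before trusting the raw pixel and label bytes (check_idx_headers is a hypothetical helper):

import gzip
import struct

def check_idx_headers(images_gz, labels_gz):
    # Image header: magic, image count, rows, cols (big-endian int32 each).
    with gzip.open(images_gz, 'rb') as f:
        magic, count, rows, cols = struct.unpack('>4i', f.read(16))
    assert magic == 2051 and (rows, cols) == (28, 28)
    # Label header: magic, label count.
    with gzip.open(labels_gz, 'rb') as f:
        label_magic, label_count = struct.unpack('>2i', f.read(8))
    assert label_magic == 2049 and label_count == count
    return count
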
@@ -225,10 +256,11 @@ def open_dataset(source, *, max_images: Optional[int]): |
225 | 256 | else: |
226 | 257 | return open_image_folder(source, max_images=max_images) |
227 | 258 | elif os.path.isfile(source): |
228 | | - if source.endswith('cifar-10-python.tar.gz'): |
| 259 | + if os.path.basename(source) == 'cifar-10-python.tar.gz': |
229 | 260 | return open_cifar10(source, max_images=max_images) |
230 | | - ext = file_ext(source) |
231 | | - if ext == 'zip': |
| 261 | + elif os.path.basename(source) == 'train-images-idx3-ubyte.gz': |
| 262 | + return open_mnist(source, max_images=max_images) |
| 263 | + elif file_ext(source) == 'zip': |
232 | 264 | return open_image_zip(source, max_images=max_images) |
233 | 265 | else: |
234 | 266 | assert False, 'unknown archive type' |
@@ -293,17 +325,18 @@ def convert_dataset( |
293 | 325 | The input dataset format is guessed from the --source argument: |
294 | 326 |
295 | 327 | \b |
296 | | - --source *_lmdb/ - Load LSUN dataset |
297 | | - --source cifar-10-python.tar.gz - Load CIFAR-10 dataset |
298 | | - --source path/ - Recursively load all images from path/ |
299 | | - --source dataset.zip - Recursively load all images from dataset.zip |
| 328 | + --source *_lmdb/ Load LSUN dataset |
| 329 | + --source cifar-10-python.tar.gz Load CIFAR-10 dataset |
| 330 | + --source train-images-idx3-ubyte.gz Load MNIST dataset |
| 331 | + --source path/ Recursively load all images from path/ |
| 332 | + --source dataset.zip Recursively load all images from dataset.zip |
300 | 333 |
301 | | - The output dataset format can be either an image folder or a zip archive. Specifying |
302 | | - the output format and path: |
| 334 | + The output dataset format can be either an image folder or a zip archive. |
| 335 | + Specifying the output format and path: |
303 | 336 |
304 | 337 | \b |
305 | | - --dest /path/to/dir - Save output files under /path/to/dir |
306 | | - --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive |
| 338 | + --dest /path/to/dir Save output files under /path/to/dir |
| 339 | + --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip |
307 | 340 |
308 | 341 | Images within the dataset archive will be stored as uncompressed PNG. |
309 | 342 |
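
A usage sketch for the new MNIST path (the script name dataset_tool.py is an assumption, since the diff does not show which file is being modified; the output name is only an example):

python dataset_tool.py --source=train-images-idx3-ubyte.gz --dest=mnist-32x32.zip

Note that open_mnist() derives the label file name from the image file name, so train-labels-idx1-ubyte.gz must sit in the same directory as train-images-idx3-ubyte.gz.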