Skip to content

Commit 84ff674

Browse files
committed
Add script to download the data
1 parent c4d9d93 commit 84ff674

File tree

3 files changed

+94
-3
lines changed

3 files changed

+94
-3
lines changed

advanced_source/cpp_cuda_graphs/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
99
if (DOWNLOAD_MNIST)
1010
message(STATUS "Downloading MNIST dataset")
1111
execute_process(
12-
COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
12+
COMMAND python ${CMAKE_CURRENT_LIST_DIR}/download_mnist.py
1313
-d ${CMAKE_BINARY_DIR}/data
1414
ERROR_VARIABLE DOWNLOAD_ERROR)
1515
if (DOWNLOAD_ERROR)

advanced_source/cpp_cuda_graphs/README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ $ make
1616
```
1717

1818
where `/path/to/libtorch` should be the path to the unzipped _LibTorch_
19-
distribution, which you can get from the [PyTorch
20-
homepage](https://pytorch.org/get-started/locally/).
19+
distribution, or PyTorch's CMake prefix path as printed by
20+
`python -c "import torch; print(torch.utils.cmake_prefix_path)"`.
21+
Please see the [PyTorch homepage](https://pytorch.org/get-started/locally/)
22+
for installation instructions.
2123

2224
Execute the compiled binary to train the model:
2325

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from __future__ import division
2+
from __future__ import print_function
3+
4+
import argparse
5+
import gzip
6+
import os
7+
import sys
8+
import urllib
9+
10+
try:
11+
from urllib.error import URLError
12+
from urllib.request import urlretrieve
13+
except ImportError:
14+
from urllib2 import URLError
15+
from urllib import urlretrieve
16+
17+
# Gzipped archive names to fetch; each one is appended to the base URL in
# main() and unpacked next to the downloaded archive.
RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]
23+
24+
25+
def report_download_progress(chunk_number, chunk_size, file_size):
    """Draw a single-line progress bar on stdout.

    The signature matches the ``reporthook`` callback of ``urlretrieve``,
    which is how ``download`` uses this function.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of one chunk in bytes.
        file_size: Total size in bytes. Values <= 0 (including the -1 that
            means "unknown") suppress the bar entirely.
    """
    # A non-positive size cannot be turned into a fraction; a size of 0 would
    # otherwise raise ZeroDivisionError.
    if file_size <= 0:
        return
    # Clamp to 1: the last chunk is usually only partially filled, so
    # chunk_number * chunk_size can exceed file_size.
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = '#' * int(64 * percent)
    # '\r' returns to the start of the line so each call redraws the bar.
    sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
    # No newline is written, so flush explicitly to make the update visible.
    sys.stdout.flush()
30+
31+
32+
def download(destination_path, url, quiet):
    """Fetch ``url`` into ``destination_path`` unless that file already exists.

    The payload is first written to ``destination_path + '.part'`` and renamed
    only after the transfer completes, so an interrupted or failed download
    never leaves a truncated file that a later run would mistake for a
    finished one.

    Args:
        destination_path: Where the downloaded file should end up.
        url: Location to download from.
        quiet: When true, suppress the "already exists" notice, the progress
            bar and the trailing newline. The "Downloading ..." line is
            printed regardless.

    Raises:
        RuntimeError: If the transfer fails with a ``URLError``.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print('{} already exists, skipping ...'.format(destination_path))
        return

    print('Downloading {} ...'.format(url))
    hook = None if quiet else report_download_progress
    partial_path = destination_path + '.part'
    try:
        urlretrieve(url, partial_path, reporthook=hook)
        os.rename(partial_path, destination_path)
    except URLError as error:
        raise RuntimeError(
            'Error downloading resource! ({}: {})'.format(url, error))
    finally:
        # Still present only if the transfer or the rename did not finish.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        if not quiet:
            # Just a newline, to end the progress-bar line.
            print()
47+
48+
49+
def unzip(zipped_path, quiet):
    """Decompress a gzip file next to itself, dropping the last extension.

    The output path is ``zipped_path`` with its final extension removed
    (``foo.gz`` -> ``foo``). If that path already exists nothing is done.
    Data is streamed in chunks to ``<output>.part`` and renamed once
    extraction succeeds, so a corrupt or truncated archive never leaves a
    partial output file that a later run would skip over.

    Args:
        zipped_path: Path to the gzip-compressed file.
        quiet: When true, suppress status messages.
    """
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print('{} already exists, skipping ... '.format(unzipped_path))
        return

    partial_path = unzipped_path + '.part'
    try:
        with gzip.open(zipped_path, 'rb') as zipped_file:
            with open(partial_path, 'wb') as unzipped_file:
                # Copy in 1 MiB chunks instead of holding the whole
                # decompressed payload in memory at once.
                for chunk in iter(lambda: zipped_file.read(1 << 20), b''):
                    unzipped_file.write(chunk)
        os.rename(partial_path, unzipped_path)
    finally:
        # Still present only if extraction or the rename did not finish.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    if not quiet:
        print('Unzipped {} ...'.format(zipped_path))
60+
61+
62+
def main():
    """Parse command-line options, then download and unpack every resource.

    Options:
        -d / --destination: directory to place the files in (default '.');
            created if it does not exist.
        -q / --quiet: passed through to ``download`` and ``unzip``.

    On Ctrl-C the script reports the interruption on stderr and exits with
    status 1, so a calling process can tell the dataset is incomplete.
    """
    parser = argparse.ArgumentParser(
        description='Download the MNIST dataset from the internet')
    parser.add_argument(
        '-d', '--destination', default='.', help='Destination directory')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help="Don't report about progress")
    options = parser.parse_args()

    # Create first and inspect the failure afterwards: checking for existence
    # before creating is racy when another process creates the directory in
    # between.
    try:
        os.makedirs(options.destination)
    except OSError:
        if not os.path.isdir(options.destination):
            raise

    # Files are fetched from this S3 bucket rather than from the original
    # host (http://yann.lecun.com/exdb/mnist/).
    base_url = 'https://ossci-datasets.s3.amazonaws.com/mnist/{}'

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, base_url.format(resource), options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        sys.stderr.write('Interrupted\n')
        sys.exit(1)
86+
87+
88+
# Run the downloader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)