1+ # Copyright 2022 IBM, Red Hat
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ import os
16+ from torchvision .datasets import MNIST
17+ from torchvision import transforms
18+
19+ def download_mnist_dataset (destination_dir ):
20+ # Ensure the destination directory exists
21+ if not os .path .exists (destination_dir ):
22+ os .makedirs (destination_dir )
23+
24+ # Define transformations
25+ transform = transforms .Compose ([
26+ transforms .ToTensor (),
27+ transforms .Normalize ((0.1307 ,), (0.3081 ,))
28+ ])
29+
30+ # Download the training data
31+ train_set = MNIST (root = destination_dir , train = True , download = True , transform = transform )
32+
33+ # Download the test data
34+ test_set = MNIST (root = destination_dir , train = False , download = True , transform = transform )
35+
36+ print (f"MNIST dataset downloaded in { destination_dir } " )
37+
38+ # Specify the directory where you
39+ script_dir = os .path .dirname (os .path .abspath (__file__ ))
40+ destination_dir = script_dir + "/mnist_datasets"
41+
42+ download_mnist_dataset (destination_dir )
0 commit comments