11#include < torch/torch.h>
2-
2+ # include < argparse/argparse.hpp >
33#include < cmath>
44#include < cstdio>
55#include < iostream>
@@ -10,9 +10,6 @@ const int64_t kNoiseSize = 100;
1010// The batch size for training.
1111const int64_t kBatchSize = 64 ;
1212
13- // The number of epochs to train.
14- const int64_t kNumberOfEpochs = 30 ;
15-
1613// Where to find the MNIST dataset.
1714const char * kDataFolder = " ./data" ;
1815
@@ -75,7 +72,43 @@ struct DCGANGeneratorImpl : nn::Module {
7572
7673TORCH_MODULE (DCGANGenerator);
7774
75+ nn::Sequential create_discriminator () {
76+ return nn::Sequential (
77+ // Layer 1
78+ nn::Conv2d (nn::Conv2dOptions (1 , 64 , 4 ).stride (2 ).padding (1 ).bias (false )),
79+ nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
80+ // Layer 2
81+ nn::Conv2d (nn::Conv2dOptions (64 , 128 , 4 ).stride (2 ).padding (1 ).bias (false )),
82+ nn::BatchNorm2d (128 ),
83+ nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
84+ // Layer 3
85+ nn::Conv2d (
86+ nn::Conv2dOptions (128 , 256 , 4 ).stride (2 ).padding (1 ).bias (false )),
87+ nn::BatchNorm2d (256 ),
88+ nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
89+ // Layer 4
90+ nn::Conv2d (nn::Conv2dOptions (256 , 1 , 3 ).stride (1 ).padding (0 ).bias (false )),
91+ nn::Sigmoid ());
92+ }
93+
7894int main (int argc, const char * argv[]) {
95+ argparse::ArgumentParser parser (" cpp/dcgan example" );
96+ parser.add_argument (" --epochs" )
97+ .help (" Number of epochs to train" )
98+ .default_value (std::int64_t {30 })
99+ .scan <' i' , int64_t >();
100+ try {
101+ parser.parse_args (argc, argv);
102+ } catch (const std::exception& err) {
103+ std::cout << err.what () << std::endl;
104+ std::cout << parser;
105+ std::exit (1 );
106+ }
107+ // The number of epochs to train, default value is 30.
108+ const int64_t kNumberOfEpochs = parser.get <int64_t >(" --epochs" );
109+ std::cout << " Traning with number of epochs: " << kNumberOfEpochs
110+ << std::endl;
111+
79112 torch::manual_seed (1 );
80113
81114 // Create the device we pass around based on whether CUDA is available.
@@ -88,33 +121,15 @@ int main(int argc, const char* argv[]) {
88121 DCGANGenerator generator (kNoiseSize );
89122 generator->to (device);
90123
91- nn::Sequential discriminator (
92- // Layer 1
93- nn::Conv2d (
94- nn::Conv2dOptions (1 , 64 , 4 ).stride (2 ).padding (1 ).bias (false )),
95- nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
96- // Layer 2
97- nn::Conv2d (
98- nn::Conv2dOptions (64 , 128 , 4 ).stride (2 ).padding (1 ).bias (false )),
99- nn::BatchNorm2d (128 ),
100- nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
101- // Layer 3
102- nn::Conv2d (
103- nn::Conv2dOptions (128 , 256 , 4 ).stride (2 ).padding (1 ).bias (false )),
104- nn::BatchNorm2d (256 ),
105- nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
106- // Layer 4
107- nn::Conv2d (
108- nn::Conv2dOptions (256 , 1 , 3 ).stride (1 ).padding (0 ).bias (false )),
109- nn::Sigmoid ());
124+ nn::Sequential discriminator = create_discriminator ();
110125 discriminator->to (device);
111126
112127 // Assume the MNIST dataset is available under `kDataFolder`;
113128 auto dataset = torch::data::datasets::MNIST (kDataFolder )
114129 .map (torch::data::transforms::Normalize<>(0.5 , 0.5 ))
115130 .map (torch::data::transforms::Stack<>());
116- const int64_t batches_per_epoch =
117- std::ceil (dataset.size ().value () / static_cast <double >(kBatchSize ));
131+ const int64_t batches_per_epoch = static_cast < int64_t >(
132+ std::ceil (dataset.size ().value () / static_cast <double >(kBatchSize ))) ;
118133
119134 auto data_loader = torch::data::make_data_loader (
120135 std::move (dataset),
@@ -136,7 +151,7 @@ int main(int argc, const char* argv[]) {
136151 int64_t checkpoint_counter = 1 ;
137152 for (int64_t epoch = 1 ; epoch <= kNumberOfEpochs ; ++epoch) {
138153 int64_t batch_index = 0 ;
139- for (torch::data::Example<>& batch : *data_loader) {
154+ for (const torch::data::Example<>& batch : *data_loader) {
140155 // Train discriminator with real images.
141156 discriminator->zero_grad ();
142157 torch::Tensor real_images = batch.data .to (device);
0 commit comments