@@ -39,14 +39,14 @@ def forward(self, x):
3939 return F .log_softmax (x , dim = 1 )
4040
4141
42- def _get_train_data_loader (training_dir , is_distributed , ** kwargs ):
42+ def _get_train_data_loader (training_dir , is_distributed , batch_size , ** kwargs ):
4343 logger .info ('Get train data loader' )
4444 dataset = datasets .MNIST (training_dir , train = True , transform = transforms .Compose ([
4545 transforms .ToTensor (),
4646 transforms .Normalize ((0.1307 ,), (0.3081 ,))
4747 ]))
4848 train_sampler = torch .utils .data .distributed .DistributedSampler (dataset ) if is_distributed else None
49- train_loader = torch .utils .data .DataLoader (dataset , batch_size = 64 , shuffle = train_sampler is None ,
49+ train_loader = torch .utils .data .DataLoader (dataset , batch_size = batch_size , shuffle = train_sampler is None ,
5050 sampler = train_sampler , ** kwargs )
5151 return train_sampler , train_loader
5252
@@ -94,7 +94,7 @@ def train(args):
9494 if use_cuda :
9595 torch .cuda .manual_seed (seed )
9696
97- train_sampler , train_loader = _get_train_data_loader (args .data_dir , is_distributed , ** kwargs )
97+ train_sampler , train_loader = _get_train_data_loader (args .data_dir , is_distributed , args . batch_size , ** kwargs )
9898 test_loader = _get_test_data_loader (args .data_dir , ** kwargs )
9999
100100 logger .debug ('Processes {}/{} ({:.0f}%) of train data' .format (
@@ -142,9 +142,11 @@ def train(args):
142142 logger .debug ('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}' .format (
143143 epoch , batch_idx * len (data ), len (train_loader .sampler ),
144144 100. * batch_idx / len (train_loader ), loss .item ()))
145- test (model , test_loader , device )
145+ accuracy = test (model , test_loader , device )
146146 save_model (model , args .model_dir )
147147
148+ logger .debug ('Overall test accuracy: {}' .format (accuracy ))
149+
148150
149151def test (model , test_loader , device ):
150152 model .eval ()
@@ -159,9 +161,12 @@ def test(model, test_loader, device):
159161 correct += pred .eq (target .view_as (pred )).sum ().item ()
160162
161163 test_loss /= len (test_loader .dataset )
164+ accuracy = 100. * correct / len (test_loader .dataset )
165+
162166 logger .debug ('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n ' .format (
163- test_loss , correct , len (test_loader .dataset ),
164- 100. * correct / len (test_loader .dataset )))
167+ test_loss , correct , len (test_loader .dataset ), accuracy ))
168+
169+ return accuracy
165170
166171
167172def model_fn (model_dir ):
@@ -181,6 +186,7 @@ def save_model(model, model_dir):
181186if __name__ == '__main__' :
182187 parser = argparse .ArgumentParser ()
183188 parser .add_argument ('--epochs' , type = int , default = 1 , metavar = 'N' )
189+ parser .add_argument ('--batch-size' , type = int , default = 64 , metavar = 'N' )
184190
185191 # Container environment
186192 parser .add_argument ('--hosts' , type = list , default = json .loads (os .environ ['SM_HOSTS' ]))
0 commit comments