@@ -34,8 +34,8 @@ def str2bool(v):
 parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
 parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')
 parser.add_argument('--log_iters', default=True, type=bool, help='Print the loss at each iteration')
-parser.add_argument('--visdom', default=True, type=str2bool, help='Use visdom for loss visualization')
-parser.add_argument('--send_images_to_visdom', type=str2bool, default=True, help='Sample a random image from every 10th batch, send it to visdom after augmentations step')
+parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom for loss visualization')
+parser.add_argument('--send_images_to_visdom', type=str2bool, default=False, help='Sample a random image from every 10th batch, send it to visdom after augmentations step')
 parser.add_argument('--save_folder', default='weights/', help='Directory for saving checkpoint models')
 parser.add_argument('--dataset_root', default=COCO_ROOT, help='Dataset root directory path')
 parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks")
@@ -104,63 +104,41 @@ def weights_init(m):
 def train():
     net.train()
     # loss counters
-    loc_loss = 0  # epoch
+    loc_loss = 0
     conf_loss = 0
     epoch = 0
     print('Loading Dataset...')
     dataset = COCODetection(args.dataset_root, args.image_set, SSDAugmentation(
         SSD_DIM, MEANS), COCOAnnotationTransform())

     epoch_size = len(dataset) // args.batch_size
-    print('Training SSD on ', dataset.name)
+    print('Training SSD on', dataset.name)
     step_index = 0
+
     if args.visdom:
-        # initialize visdom loss plot
-        lot = viz.line(
-            X=torch.zeros((1,)).cpu(),
-            Y=torch.zeros((1, 3)).cpu(),
-            opts=dict(
-                xlabel='Iteration',
-                ylabel='Loss',
-                title='Current SSD Training Loss',
-                legend=['Loc Loss', 'Conf Loss', 'Loss']
-            )
-        )
-        epoch_lot = viz.line(
-            X=torch.zeros((1,)).cpu(),
-            Y=torch.zeros((1, 3)).cpu(),
-            opts=dict(
-                xlabel='Epoch',
-                ylabel='Loss',
-                title='Epoch SSD Training Loss',
-                legend=['Loc Loss', 'Conf Loss', 'Loss']
-            )
-        )
-    batch_iterator = None
+        vis_title = 'SSD.PyTorch on ' + args.image_set
+        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
+        iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend)
+        epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)
     data_loader = data.DataLoader(dataset, args.batch_size,
                                   num_workers=args.num_workers,
                                   shuffle=True, collate_fn=detection_collate,
                                   pin_memory=True)
+    # create batch iterator
+    batch_iterator = iter(data_loader)
     for iteration in range(args.start_iter, args.max_iter):
-        if (not batch_iterator) or (iteration % epoch_size == 0):
-            # create batch iterator
-            batch_iterator = iter(data_loader)
-        if iteration in STEP_VALUES:
-            step_index += 1
-            adjust_learning_rate(optimizer, args.gamma, step_index)
-        if args.visdom:
-            viz.line(
-                X=torch.ones((1, 3)).cpu() * epoch,
-                Y=torch.Tensor([loc_loss, conf_loss,
-                                loc_loss + conf_loss]).unsqueeze(0).cpu() / epoch_size,
-                win=epoch_lot,
-                update='append'
-            )
+        if iteration != 0 and (iteration % epoch_size == 0) and args.visdom:
+            update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None,
+                            'append', epoch_size)
             # reset epoch loss counters
             loc_loss = 0
             conf_loss = 0
             epoch += 1

+        if iteration in STEP_VALUES:
+            step_index += 1
+            adjust_learning_rate(optimizer, args.gamma, step_index)
+
         # load train data
         images, targets = next(batch_iterator)

@@ -182,29 +160,15 @@ def train():
         t1 = time.time()
         loc_loss += loss_l.data[0]
         conf_loss += loss_c.data[0]
+
         if iteration % 10 == 0:
-            print('Timer : %.4f sec.' % (t1 - t0))
+            print('timer : %.4f sec.' % (t1 - t0))
             print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')
-            if args.visdom and args.send_images_to_visdom:
-                random_batch_index = np.random.randint(images.size(0))
-                viz.image(images.data[random_batch_index].cpu().numpy())
+
         if args.visdom:
-            viz.line(
-                X=torch.ones((1, 3)).cpu() * iteration,
-                Y=torch.Tensor([loss_l.data[0], loss_c.data[0],
-                                loss_l.data[0] + loss_c.data[0]]).unsqueeze(0).cpu(),
-                win=lot,
-                update='append'
-            )
-            # hacky fencepost solution for 0th epoch plot
-            if iteration == 0:
-                viz.line(
-                    X=torch.zeros((1, 3)).cpu(),
-                    Y=torch.Tensor([loc_loss, conf_loss,
-                                    loc_loss + conf_loss]).unsqueeze(0).cpu(),
-                    win=epoch_lot,
-                    update=True
-                )
+            update_vis_plot(iteration, loss_l.data[0], loss_c.data[0],
+                            iter_plot, epoch_plot, 'append')
+
         if iteration % 5000 == 0:
             print('Saving state, iter:', iteration)
             torch.save(ssd_net.state_dict(), 'weights/ssd300_COCO_' +
@@ -224,5 +188,36 @@ def adjust_learning_rate(optimizer, gamma, step):
         param_group['lr'] = lr


+def create_vis_plot(_xlabel, _ylabel, _title, _legend):
+    return viz.line(
+        X=torch.zeros((1,)).cpu(),
+        Y=torch.zeros((1, 3)).cpu(),
+        opts=dict(
+            xlabel=_xlabel,
+            ylabel=_ylabel,
+            title=_title,
+            legend=_legend
+        )
+    )
+
+
+def update_vis_plot(iteration, loc, conf, window1, window2, update_type,
+                    epoch_size=1):
+    viz.line(
+        X=torch.ones((1, 3)).cpu() * iteration,
+        Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu() / epoch_size,
+        win=window1,
+        update=update_type
+    )
+    # initialize epoch plot on first iteration
+    if iteration == 0:
+        viz.line(
+            X=torch.zeros((1, 3)).cpu(),
+            Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu(),
+            win=window2,
+            update=True
+        )
+
+
 if __name__ == '__main__':
     train()
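
The commit replaces the inline `viz.line` calls with the `create_vis_plot`/`update_vis_plot` helpers added above, and flips `--visdom` to default to `False`, so loss plotting now has to be requested explicitly on the command line (e.g. `--visdom True`). Below is a minimal standalone sketch, not part of the commit, that exercises the two helpers with dummy loss values; it assumes `torch` and `visdom` are installed and a server has been started with `python -m visdom.server` (plots appear at http://localhost:8097 by default). The window title, the simplified helper signature, and the dummy loss schedule are illustrative only.

```python
# Standalone smoke test for the new visdom plotting helpers (illustrative, not repo code).
import torch
import visdom

viz = visdom.Visdom()  # assumes a running `python -m visdom.server`


def create_vis_plot(_xlabel, _ylabel, _title, _legend):
    # same body as the helper added in this commit
    return viz.line(
        X=torch.zeros((1,)).cpu(),
        Y=torch.zeros((1, 3)).cpu(),
        opts=dict(xlabel=_xlabel, ylabel=_ylabel, title=_title, legend=_legend)
    )


def update_vis_plot(iteration, loc, conf, window, update_type, epoch_size=1):
    # same as the per-iteration branch of the helper added in this commit
    viz.line(
        X=torch.ones((1, 3)).cpu() * iteration,
        Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu() / epoch_size,
        win=window,
        update=update_type
    )


legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
iter_plot = create_vis_plot('Iteration', 'Loss', 'helper smoke test', legend)

for it in range(1, 101):
    loc, conf = 2.0 / it, 1.0 / it  # dummy decaying losses
    update_vis_plot(it, loc, conf, iter_plot, 'append')
```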