2929from pymic .transform .trans_dict import TransformDict
3030from pymic .util .post_process import PostProcessDict
3131from pymic .util .image_process import convert_label
32- from pymic .util .general import mixup
32+ from pymic .util .general import mixup , tensor_shape_match
3333
3434class SegmentationAgent (NetRunAgent ):
3535 def __init__ (self , config , stage = 'train' ):
@@ -259,7 +259,8 @@ def train_valid(self):
259259 ckpt_prefix = self .config ['training' ].get ('ckpt_prefix' , None )
260260 if (ckpt_prefix is None ):
261261 ckpt_prefix = ckpt_dir .split ('/' )[- 1 ]
262- iter_start = self .config ['training' ]['iter_start' ]
262+ # iter_start = self.config['training']['iter_start']
263+ iter_start = 0
263264 iter_max = self .config ['training' ]['iter_max' ]
264265 iter_valid = self .config ['training' ]['iter_valid' ]
265266 iter_save = self .config ['training' ].get ('iter_save' , None )
@@ -274,21 +275,32 @@ def train_valid(self):
274275 self .max_val_dice = 0.0
275276 self .max_val_it = 0
276277 self .best_model_wts = None
277- self .checkpoint = None
278- if (iter_start > 0 ):
279- checkpoint_file = "{0:}/{1:}_{2:}.pt" .format (ckpt_dir , ckpt_prefix , iter_start )
280- self .checkpoint = torch .load (checkpoint_file , map_location = self .device )
281- # assert(self.checkpoint['iteration'] == iter_start)
282- if (len (device_ids ) > 1 ):
283- self .net .module .load_state_dict (self .checkpoint ['model_state_dict' ])
278+ checkpoint = None
279+ # Initialize the network from a pre-trained checkpoint; only parameters whose names and shapes match the current model are loaded
280+ ckpt_init_name = self .config ['training' ].get ('ckpt_init_name' , None )
281+ ckpt_init_mode = self .config ['training' ].get ('ckpt_init_mode' , 0 )
282+ ckpt_for_optm = None
283+ if (ckpt_init_name is not None ):
284+ checkpoint = torch .load (ckpt_dir + "/" + ckpt_init_name , map_location = self .device )
285+ pretrained_dict = checkpoint ['model_state_dict' ]
286+ model_dict = self .net .module .state_dict () if (len (device_ids ) > 1 ) else self .net .state_dict ()
287+ pretrained_dict = {k : v for k , v in pretrained_dict .items () if \
288+ k in model_dict and tensor_shape_match (pretrained_dict [k ], model_dict [k ])}
289+ logging .info ("Initializing the following parameters with pre-trained model" )
290+ for k in pretrained_dict :
291+ logging .info (k )
292+ if (len (device_ids ) > 1 ):
293+ self .net .module .load_state_dict (pretrained_dict , strict = False )
284294 else :
285- self .net .load_state_dict (self .checkpoint ['model_state_dict' ])
286- self .max_val_dice = self .checkpoint .get ('valid_pred' , 0 )
287- # self.max_val_it = self.checkpoint['iteration']
288- self .max_val_it = iter_start
289- self .best_model_wts = self .checkpoint ['model_state_dict' ]
290-
291- self .create_optimizer (self .get_parameters_to_update ())
295+ self .net .load_state_dict (pretrained_dict , strict = False )
296+
297+ if (ckpt_init_mode > 0 ): # Load other information
298+ self .max_val_dice = checkpoint .get ('valid_pred' , 0 )
299+ iter_start = checkpoint ['iteration' ] - 1
300+ self .max_val_it = iter_start
301+ self .best_model_wts = checkpoint ['model_state_dict' ]
302+ ckpt_for_optm = checkpoint
303+ self .create_optimizer (self .get_parameters_to_update (), ckpt_for_optm )
292304 self .create_loss_calculator ()
293305
294306 self .trainIter = iter (self .train_loader )
0 commit comments