# -*- coding: utf-8 -*-
from __future__ import print_function, division
import copy
import logging
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from datetime import datetime
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
from pymic.io.nifty_dataset import NiftyDataset
from pymic.loss.seg.util import get_soft_label
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
from pymic.loss.seg.util import get_classwise_dice
from pymic.net_run.infer_func import Inferer
from pymic.net_run.agent_seg import SegmentationAgent
from pymic.transform.trans_dict import TransformDict
from pymic.loss.seg.mse import MAELoss, MSELoss
24+
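# Map the `loss_type` value in the configuration's `training` section to a
# regression loss. Self-supervised pretraining here treats the restored image
# as a regression target, so only MAE and MSE losses are supported.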
RegressionLossDict = {
    'MAELoss': MAELoss,
    'MSELoss': MSELoss
}
29+
class SelfSLSegAgent(SegmentationAgent):
    """
    Abstract class for self-supervised segmentation.

    :param config: (dict) A dictionary containing the configuration.
    :param stage: (str) One of the stages: `train` (default), `inference` or `test`.

    .. note::

        In the configuration dictionary, in addition to the four sections (`dataset`,
        `network`, `training` and `inference`) used in fully supervised learning, an
        extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl`
        for details.
    """
    def __init__(self, config, stage = 'train'):
        super(SelfSLSegAgent, self).__init__(config, stage)
        self.transform_dict = TransformDict
46+
    def create_loss_calculator(self):
        if(self.loss_dict is None):
            self.loss_dict = RegressionLossDict
        loss_name = self.config['training']['loss_type']
        if isinstance(loss_name, (list, tuple)):
            raise ValueError("A list of loss functions is not supported for " +
                "self-supervised learning, but got {0:}".format(loss_name))
        elif(loss_name not in self.loss_dict):
            raise ValueError("Undefined loss function {0:}".format(loss_name))
        else:
            loss_param = self.config['training']
            loss_param['loss_softmax'] = False
            base_loss = self.loss_dict[loss_name](loss_param)
        if(self.config['training'].get('deep_supervise', False)):
            raise ValueError("Deep supervised loss not implemented for self-supervised learning")
            # weight = self.config['training'].get('deep_supervise_weight', None)
            # mode   = self.config['training'].get('deep_supervise_mode', 2)
            # params = {'deep_supervise_weight': weight,
            #           'deep_supervise_mode': mode,
            #           'base_loss': base_loss}
            # self.loss_calculator = DeepSuperviseLoss(params)
        else:
            self.loss_calculator = base_loss
69+
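    # Run `iter_valid` training iterations and return the average loss. In the
    # self-supervised setting, `label` is typically the reconstruction target
    # (e.g., the original image) produced by the transform pipeline, rather
    # than a segmentation mask.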
    def training(self):
        iter_valid = self.config['training']['iter_valid']
        train_loss = 0
        self.net.train()
        for it in range(iter_valid):
            try:
                data = next(self.trainIter)
            except StopIteration:
                self.trainIter = iter(self.train_loader)
                data = next(self.trainIter)
            # get the inputs
            inputs = self.convert_tensor_type(data['image'])
            label  = self.convert_tensor_type(data['label'])

            # for debug
            # from pymic.io.image_read_write import save_nd_array_as_image
            # for i in range(inputs.shape[0]):
            #     image_i = inputs[i][0]
            #     image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i)
            #     save_nd_array_as_image(image_i, image_name, reference_name = None)
            # return

            inputs, label = inputs.to(self.device), label.to(self.device)

            # zero the parameter gradients
            self.optimizer.zero_grad()

            # forward + backward + optimize
            outputs = self.net(inputs)
            # map the network output to [0, 1] to match the intensity range
            # of the reconstruction target
            outputs = nn.Sigmoid()(outputs)
            loss = self.get_loss_value(data, outputs, label)
            loss.backward()
            self.optimizer.step()
            train_loss = train_loss + loss.item()
107+
        train_avg_loss = train_loss / iter_valid
        train_scalers  = {'loss': train_avg_loss}
        return train_scalers
111+
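    # Compute the average reconstruction loss on the validation set, using the
    # `Inferer` configured by the `testing` section of the configuration.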
    def validation(self):
        if(self.inferer is None):
            infer_cfg = self.config['testing']
            self.inferer = Inferer(infer_cfg)

        valid_loss_list = []
        validIter = iter(self.valid_loader)
        with torch.no_grad():
            self.net.eval()
            for data in validIter:
                inputs = self.convert_tensor_type(data['image'])
                label  = self.convert_tensor_type(data['label'])
                inputs, label = inputs.to(self.device), label.to(self.device)
                outputs = self.inferer.run(self.net, inputs)
                outputs = nn.Sigmoid()(outputs)
                loss = self.get_loss_value(data, outputs, label)
                valid_loss_list.append(loss.item())

        valid_avg_loss = np.asarray(valid_loss_list).mean()
        valid_scalers  = {'loss': valid_avg_loss}
        return valid_scalers
134+
    def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
        loss_scalar = {'train': train_scalars['loss'],
                       'valid': valid_scalars['loss']}
        self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
        self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
        logging.info('train loss {0:.4f}'.format(train_scalars['loss']))
        logging.info('valid loss {0:.4f}'.format(valid_scalars['loss']))
142+
    def train_valid(self):
        device_ids = self.config['training']['gpus']
        if(len(device_ids) > 1):
            self.device = torch.device("cuda:0")
            self.net = nn.DataParallel(self.net, device_ids = device_ids)
        else:
            self.device = torch.device("cuda:{0:}".format(device_ids[0]))
        self.net.to(self.device)
        ckpt_dir    = self.config['training']['ckpt_save_dir']
        ckpt_prefix = self.config['training'].get('ckpt_prefix', None)
        if(ckpt_prefix is None):
            ckpt_prefix = ckpt_dir.split('/')[-1]
        iter_start = self.config['training']['iter_start']
        iter_max   = self.config['training']['iter_max']
        iter_valid = self.config['training']['iter_valid']
        iter_save  = self.config['training'].get('iter_save', None)
        early_stop_it = self.config['training'].get('early_stop_patience', None)
        if(iter_save is None):
            iter_save_list = [iter_max]
        elif(isinstance(iter_save, (tuple, list))):
            iter_save_list = iter_save
        else:
            iter_save_list = range(0, iter_max + 1, iter_save)
166+
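        # track the best model seen so far: `min_val_loss` is the lowest
        # validation loss and `max_val_it` the iteration at which it occurred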
        self.min_val_loss = 10000.0
        self.max_val_it   = 0
        self.best_model_wts = None
        self.checkpoint = None
        if(iter_start > 0):
            checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start)
            self.checkpoint = torch.load(checkpoint_file, map_location = self.device)
            # assert(self.checkpoint['iteration'] == iter_start)
            if(len(device_ids) > 1):
                self.net.module.load_state_dict(self.checkpoint['model_state_dict'])
            else:
                self.net.load_state_dict(self.checkpoint['model_state_dict'])
            self.min_val_loss = self.checkpoint.get('valid_loss', 10000)
            # self.max_val_it = self.checkpoint['iteration']
            self.max_val_it = iter_start
            self.best_model_wts = self.checkpoint['model_state_dict']
183+
        self.create_optimizer(self.get_parameters_to_update())
        self.create_loss_calculator()

        self.trainIter = iter(self.train_loader)

        logging.info("{0:} training start".format(str(datetime.now())[:-7]))
        self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir'])
        self.glob_it = iter_start
        for it in range(iter_start, iter_max, iter_valid):
            lr_value = self.optimizer.param_groups[0]['lr']
            t0 = time.time()
            train_scalars = self.training()
            t1 = time.time()
            valid_scalars = self.validation()
            t2 = time.time()
            if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                self.scheduler.step(-valid_scalars['loss'])
            else:
                self.scheduler.step()

            self.glob_it = it + iter_valid
            logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it))
            logging.info('learning rate {0:}'.format(lr_value))
            logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1))
            self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)
            if(valid_scalars['loss'] < self.min_val_loss):
                self.min_val_loss = valid_scalars['loss']
                self.max_val_it   = self.glob_it
                if(len(device_ids) > 1):
                    self.best_model_wts = copy.deepcopy(self.net.module.state_dict())
                else:
                    self.best_model_wts = copy.deepcopy(self.net.state_dict())

            stop_now = (early_stop_it is not None) and \
                       (self.glob_it - self.max_val_it > early_stop_it)
            if((self.glob_it in iter_save_list) or stop_now):
                save_dict = {'iteration': self.glob_it,
                             'valid_loss': valid_scalars['loss'],
                             'model_state_dict': self.net.module.state_dict()
                                 if len(device_ids) > 1 else self.net.state_dict(),
                             'optimizer_state_dict': self.optimizer.state_dict()}
                save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it)
                torch.save(save_dict, save_name)
                with open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') as txt_file:
                    txt_file.write(str(self.glob_it))
            if(stop_now):
                logging.info("The training is early stopped")
                break
        # save the best performing checkpoint
        save_dict = {'iteration': self.max_val_it,
                     'valid_loss': self.min_val_loss,
                     'model_state_dict': self.best_model_wts,
                     'optimizer_state_dict': self.optimizer.state_dict()}
        save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it)
        torch.save(save_dict, save_name)
        with open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') as txt_file:
            txt_file.write(str(self.max_val_it))
        logging.info('The best performing iter is {0:}, valid loss {1:}'.format(
            self.max_val_it, self.min_val_loss))
        self.summ_writer.close()
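
# Example usage -- a minimal sketch, not part of the original module. The config
# keys shown are illustrative assumptions; see the PyMIC documentation for the
# full schema, and note that in practice the surrounding PyMIC pipeline builds
# the datasets and data loaders before train_valid() is called:
#
#     config = {
#         'dataset' : {...},   # data root, csv files, transforms
#         'network' : {...},   # network type and hyper-parameters
#         'training': {'loss_type': 'MSELoss', 'gpus': [0],
#                      'iter_start': 0, 'iter_max': 10000, 'iter_valid': 100,
#                      'ckpt_save_dir': 'model/self_sl'},
#         'testing' : {...},   # inference settings used for validation
#         'self_supervised_learning': {...}
#     }
#     agent = SelfSLSegAgent(config, stage = 'train')
#     agent.train_valid()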