Skip to content

Commit fee0a47

Browse files
committed
update ssl and wsl
calculate mean dice for foreground classes; allow loading pre-trained models
1 parent 8c15798 commit fee0a47

File tree

13 files changed

+86
-58
lines changed

13 files changed

+86
-58
lines changed

pymic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from __future__ import absolute_import
2-
from . import *
2+
__version__ = "0.3.1"

pymic/io/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from __future__ import absolute_import
2-
from . import *
2+
from pymic.io.image_read_write import *
3+
from pymic.io.nifty_dataset import *
4+
from pymic.io.h5_dataset import *

pymic/net/net2d/unet2d.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ class Decoder(nn.Module):
131131
:param class_num: (int) The class number for segmentation task.
132132
:param bilinear: (bool) Using bilinear for up-sampling or not.
133133
If False, deconvolution will be used for up-sampling.
134-
:param multiscale_pred: (bool) Get multi-scale prediction.
135134
"""
136135
def __init__(self, params):
137136
super(Decoder, self).__init__()
@@ -140,8 +139,7 @@ def __init__(self, params):
140139
self.ft_chns = self.params['feature_chns']
141140
self.dropout = self.params['dropout']
142141
self.n_class = self.params['class_num']
143-
self.bilinear = self.params.get('bilinear', True)
144-
self.mul_pred = self.params.get('multiscale_pred', False)
142+
self.bilinear = self.params['bilinear']
145143

146144
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
147145

@@ -151,10 +149,6 @@ def __init__(self, params):
151149
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
152150
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
153151
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
154-
if(self.mul_pred):
155-
self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1)
156-
self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1)
157-
self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1)
158152

159153
def forward(self, x):
160154
if(len(self.ft_chns) == 5):
@@ -169,11 +163,6 @@ def forward(self, x):
169163
x_d1 = self.up3(x_d2, x1)
170164
x_d0 = self.up4(x_d1, x0)
171165
output = self.out_conv(x_d0)
172-
if(self.mul_pred):
173-
output1 = self.out_conv1(x_d1)
174-
output2 = self.out_conv2(x_d2)
175-
output3 = self.out_conv3(x_d3)
176-
output = [output, output1, output2, output3]
177166
return output
178167

179168
class UNet2D(nn.Module):

pymic/net_run/agent_abstract.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def worker_init_fn(worker_id):
276276
self.test_loader = torch.utils.data.DataLoader(self.test_set,
277277
batch_size = bn_test, shuffle=False, num_workers= bn_test)
278278

279-
def create_optimizer(self, params):
279+
def create_optimizer(self, params, checkpoint = None):
280280
"""
281281
Create optimizer based on configuration.
282282
@@ -288,9 +288,9 @@ def create_optimizer(self, params):
288288
self.optimizer = get_optimizer(opt_params['optimizer'],
289289
params, opt_params)
290290
last_iter = -1
291-
if(self.checkpoint is not None):
292-
self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
293-
last_iter = self.checkpoint['iteration'] - 1
291+
if(checkpoint is not None):
292+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
293+
last_iter = checkpoint['iteration'] - 1
294294
if(self.scheduler is None):
295295
opt_params["last_iter"] = last_iter
296296
self.scheduler = get_lr_scheduler(self.optimizer, opt_params)

pymic/net_run/agent_seg.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pymic.transform.trans_dict import TransformDict
3030
from pymic.util.post_process import PostProcessDict
3131
from pymic.util.image_process import convert_label
32-
from pymic.util.general import mixup
32+
from pymic.util.general import mixup, tensor_shape_match
3333

3434
class SegmentationAgent(NetRunAgent):
3535
def __init__(self, config, stage = 'train'):
@@ -259,7 +259,8 @@ def train_valid(self):
259259
ckpt_prefix = self.config['training'].get('ckpt_prefix', None)
260260
if(ckpt_prefix is None):
261261
ckpt_prefix = ckpt_dir.split('/')[-1]
262-
iter_start = self.config['training']['iter_start']
262+
# iter_start = self.config['training']['iter_start']
263+
iter_start = 0
263264
iter_max = self.config['training']['iter_max']
264265
iter_valid = self.config['training']['iter_valid']
265266
iter_save = self.config['training'].get('iter_save', None)
@@ -274,21 +275,32 @@ def train_valid(self):
274275
self.max_val_dice = 0.0
275276
self.max_val_it = 0
276277
self.best_model_wts = None
277-
self.checkpoint = None
278-
if(iter_start > 0):
279-
checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start)
280-
self.checkpoint = torch.load(checkpoint_file, map_location = self.device)
281-
# assert(self.checkpoint['iteration'] == iter_start)
282-
if(len(device_ids) > 1):
283-
self.net.module.load_state_dict(self.checkpoint['model_state_dict'])
278+
checkpoint = None
279+
# initialize the network with pre-trained weights
280+
ckpt_init_name = self.config['training'].get('ckpt_init_name', None)
281+
ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0)
282+
ckpt_for_optm = None
283+
if(ckpt_init_name is not None):
284+
checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device)
285+
pretrained_dict = checkpoint['model_state_dict']
286+
model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict()
287+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if \
288+
k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])}
289+
logging.info("Initializing the following parameters with pre-trained model")
290+
for k in pretrained_dict:
291+
logging.info(k)
292+
if (len(device_ids) > 1):
293+
self.net.module.load_state_dict(pretrained_dict, strict = False)
284294
else:
285-
self.net.load_state_dict(self.checkpoint['model_state_dict'])
286-
self.max_val_dice = self.checkpoint.get('valid_pred', 0)
287-
# self.max_val_it = self.checkpoint['iteration']
288-
self.max_val_it = iter_start
289-
self.best_model_wts = self.checkpoint['model_state_dict']
290-
291-
self.create_optimizer(self.get_parameters_to_update())
295+
self.net.load_state_dict(pretrained_dict, strict = False)
296+
297+
if(ckpt_init_mode > 0): # Load other information
298+
self.max_val_dice = checkpoint.get('valid_pred', 0)
299+
iter_start = checkpoint['iteration'] - 1
300+
self.max_val_it = iter_start
301+
self.best_model_wts = checkpoint['model_state_dict']
302+
ckpt_for_optm = checkpoint
303+
self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm)
292304
self.create_loss_calculator()
293305

294306
self.trainIter = iter(self.train_loader)

pymic/net_run_nll/nll_co_teaching.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def forward(self, x):
3232
if(self.training):
3333
return out1, out2
3434
else:
35-
return (out1 + out2) / 3
35+
return (out1 + out2) / 2
3636

3737
class NLLCoTeaching(SegmentationAgent):
3838
"""
@@ -144,13 +144,13 @@ def training(self):
144144
train_avg_loss1 = train_loss1 / iter_valid
145145
train_avg_loss2 = train_loss2 / iter_valid
146146
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
147-
train_avg_dice = train_cls_dice.mean()
147+
train_avg_dice = train_cls_dice[1:].mean()
148148

149149
train_scalers = {'loss': (train_avg_loss1 + train_avg_loss2) / 2,
150150
'loss1':train_avg_loss1, 'loss2': train_avg_loss2,
151151
'loss_no_select1':train_avg_loss_no_select1,
152152
'loss_no_select2':train_avg_loss_no_select2,
153-
'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
153+
'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
154154
return train_scalers
155155

156156
def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
@@ -159,7 +159,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
159159
loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'],
160160
'net2':train_scalars['loss_no_select2']}
161161

162-
dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
162+
dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']}
163163
self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
164164
self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it)
165165
self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it)
@@ -171,9 +171,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
171171
'valid':valid_scalars['class_dice'][c]}
172172
self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
173173

174-
logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
175-
train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
174+
logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format(
175+
train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \
176176
' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
177-
logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
178-
valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \
177+
logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format(
178+
valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \
179179
' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]")

pymic/net_run_nll/nll_dast.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import torch.nn as nn
77
import torchvision.transforms as transforms
8-
from torch.optim import lr_scheduler
98
from pymic.io.nifty_dataset import NiftyDataset
109
from pymic.loss.seg.util import get_soft_label
1110
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
@@ -257,11 +256,11 @@ def training(self):
257256
train_avg_loss_sup = train_loss_sup / iter_valid
258257
train_avg_loss_reg = train_loss_reg / iter_valid
259258
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
260-
train_avg_dice = train_cls_dice.mean()
259+
train_avg_dice = train_cls_dice[1:].mean()
261260

262261
train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
263262
'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc,
264-
'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
263+
'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
265264
return train_scalers
266265

267266
def train_valid(self):

pymic/net_run_nll/nll_main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@ def main():
2828
log_dir = config['training']['ckpt_save_dir']
2929
if(not os.path.exists(log_dir)):
3030
os.mkdir(log_dir)
31-
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
32-
format='%(message)s')
31+
if sys.version.startswith("3.9"):
32+
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
33+
format='%(message)s', force=True) # for python 3.9
34+
else:
35+
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
36+
format='%(message)s') # for python 3.6
3337
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
3438
logging_config(config)
3539
nll_method = config['noisy_label_learning']['nll_method']

pymic/net_run_nll/nll_trinet.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,13 @@ def training(self):
140140
train_avg_loss1 = train_loss1 / iter_valid
141141
train_avg_loss2 = train_loss2 / iter_valid
142142
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
143-
train_avg_dice = train_cls_dice.mean()
143+
train_avg_dice = train_cls_dice[1:].mean()
144144

145145
train_scalers = {'loss': (train_avg_loss1 + train_avg_loss2) / 2,
146146
'loss1':train_avg_loss1, 'loss2': train_avg_loss2,
147147
'loss_no_select1':train_avg_loss_no_select1,
148148
'loss_no_select2':train_avg_loss_no_select2,
149-
'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
149+
'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
150150
return train_scalers
151151

152152
def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
@@ -155,7 +155,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
155155
loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'],
156156
'net2':train_scalars['loss_no_select2']}
157157

158-
dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
158+
dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']}
159159
self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
160160
self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it)
161161
self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it)
@@ -167,9 +167,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
167167
'valid':valid_scalars['class_dice'][c]}
168168
self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
169169

170-
logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
171-
train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
170+
logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format(
171+
train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \
172172
' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
173-
logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
174-
valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \
173+
logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format(
174+
valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \
175175
' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]")

pymic/net_run_ssl/ssl_cps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
166166
self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
167167

168168
logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
169-
train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
169+
train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \
170170
' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
171171
logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
172-
valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \
172+
valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \
173173
' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]")

0 commit comments

Comments
 (0)