
Commit d4d51dc

add self supervised learning

1 parent 9ab3fca commit d4d51dc

File tree

5 files changed: +337 −4 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ BibTeX entry:
 
 # Features
 PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
-* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
+* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, weakly-supervised and noisy-label learning.
 * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC.
 * Easy-to-use I/O interface to read and write different 2D and 3D images.
 * Various data pre-processing/transformation methods before sending a tensor into a network.

pymic/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 from __future__ import absolute_import
-__version__ = "0.3.1"
+__version__ = "0.3.1.1"
pymic/net_run_self_sl/self_sl_agent.py (new file)

Lines changed: 245 additions & 0 deletions

@@ -0,0 +1,245 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division
import copy
import logging
import time
import numpy as np
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from datetime import datetime
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
from pymic.io.nifty_dataset import NiftyDataset
from pymic.loss.seg.util import get_soft_label
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
from pymic.loss.seg.util import get_classwise_dice
from pymic.net_run.infer_func import Inferer
from pymic.net_run.agent_seg import SegmentationAgent
from pymic.transform.trans_dict import TransformDict
from pymic.loss.seg.mse import MAELoss, MSELoss

RegressionLossDict = {
    'MAELoss': MAELoss,
    'MSELoss': MSELoss
    }

class SelfSLSegAgent(SegmentationAgent):
    """
    Agent for self-supervised segmentation.

    :param config: (dict) A dictionary containing the configuration.
    :param stage: (str) One of `train` (default), `inference` or `test`.

    .. note::

        In the configuration dictionary, in addition to the four sections (`dataset`,
        `network`, `training` and `inference`) used in fully supervised learning, an
        extra section `self_supervised_learning` is needed. See :doc:`usage.ssl` for details.
    """
    def __init__(self, config, stage = 'train'):
        super(SelfSLSegAgent, self).__init__(config, stage)
        self.transform_dict = TransformDict

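For reference, a configuration sketch matching the docstring above. The five section names come from the docstring and the entry script; the keys shown are illustrative examples only (`loss_type`, `iter_max` and `iter_valid` are read in the methods below), not a complete or authoritative config:

    [dataset]
    train_transform = ...
    valid_transform = ...

    [network]
    ...

    [training]
    loss_type  = MSELoss
    iter_max   = 10000
    iter_valid = 100

    [inference]
    ...

    [self_supervised_learning]
    self_sl_method = default
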
    def create_loss_calculator(self):
        if(self.loss_dict is None):
            self.loss_dict = RegressionLossDict
        loss_name = self.config['training']['loss_type']
        if isinstance(loss_name, (list, tuple)):
            raise ValueError("A list of loss functions is not supported here, got {0:}".format(loss_name))
        elif (loss_name not in self.loss_dict):
            raise ValueError("Undefined loss function {0:}".format(loss_name))
        else:
            loss_param = self.config['training']
            # the network output is activated by sigmoid rather than softmax
            loss_param['loss_softmax'] = False
            base_loss = self.loss_dict[loss_name](self.config['training'])
        if(self.config['training'].get('deep_supervise', False)):
            raise ValueError("Deep supervised loss not implemented for self-supervised learning")
            # weight = self.config['training'].get('deep_supervise_weight', None)
            # mode   = self.config['training'].get('deep_supervise_mode', 2)
            # params = {'deep_supervise_weight': weight,
            #           'deep_supervise_mode': mode,
            #           'base_loss': base_loss}
            # self.loss_calculator = DeepSuperviseLoss(params)
        else:
            self.loss_calculator = base_loss

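With `loss_softmax` disabled and a sigmoid applied to the network output (see `training` below), the regression losses amount to voxel-wise distances between the prediction and the reconstruction target. A minimal plain-PyTorch sketch of the objective; the shapes are illustrative, and the real MAELoss/MSELoss classes are constructed from the `training` config section as shown above:

    import torch

    pred   = torch.sigmoid(torch.randn(2, 1, 16, 32, 32))  # sigmoid-activated network output
    target = torch.rand(2, 1, 16, 32, 32)                   # reconstruction target (data['label'])
    mse = torch.mean((pred - target) ** 2)      # what an MSE-style loss reduces to
    mae = torch.mean(torch.abs(pred - target))  # what an MAE-style loss reduces to
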
    def training(self):
        iter_valid = self.config['training']['iter_valid']
        train_loss = 0
        self.net.train()
        for it in range(iter_valid):
            try:
                data = next(self.trainIter)
            except StopIteration:
                self.trainIter = iter(self.train_loader)
                data = next(self.trainIter)
            # get the inputs; for self-supervised learning, 'label' holds the
            # reconstruction target rather than a segmentation mask
            inputs = self.convert_tensor_type(data['image'])
            label  = self.convert_tensor_type(data['label'])

            # for debug
            # from pymic.io.image_read_write import save_nd_array_as_image
            # for i in range(inputs.shape[0]):
            #     image_i = inputs[i][0]
            #     image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i)
            #     save_nd_array_as_image(image_i, image_name, reference_name = None)
            # return

            inputs, label = inputs.to(self.device), label.to(self.device)

            # zero the parameter gradients
            self.optimizer.zero_grad()

            # forward + backward + optimize
            outputs = self.net(inputs)
            outputs = nn.Sigmoid()(outputs)
            loss = self.get_loss_value(data, outputs, label)
            loss.backward()
            self.optimizer.step()
            train_loss = train_loss + loss.item()
            # if the network returns multiple predictions, keep only the main one
            if(isinstance(outputs, (tuple, list))):
                outputs = outputs[0]

        train_avg_loss = train_loss / iter_valid
        train_scalers = {'loss': train_avg_loss}
        return train_scalers

    def validation(self):
        if(self.inferer is None):
            infer_cfg = self.config['testing']
            self.inferer = Inferer(infer_cfg)

        valid_loss_list = []
        validIter = iter(self.valid_loader)
        with torch.no_grad():
            self.net.eval()
            for data in validIter:
                inputs = self.convert_tensor_type(data['image'])
                label  = self.convert_tensor_type(data['label'])
                inputs, label = inputs.to(self.device), label.to(self.device)
                outputs = self.inferer.run(self.net, inputs)
                outputs = nn.Sigmoid()(outputs)
                loss = self.get_loss_value(data, outputs, label)
                valid_loss_list.append(loss.item())

        valid_avg_loss = np.asarray(valid_loss_list).mean()
        valid_scalers = {'loss': valid_avg_loss}
        return valid_scalers

    def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
        loss_scalar = {'train': train_scalars['loss'],
                       'valid': valid_scalars['loss']}
        self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
        self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
        logging.info('train loss {0:.4f}'.format(train_scalars['loss']))
        logging.info('valid loss {0:.4f}'.format(valid_scalars['loss']))

    def train_valid(self):
        device_ids = self.config['training']['gpus']
        if(len(device_ids) > 1):
            self.device = torch.device("cuda:0")
            self.net = nn.DataParallel(self.net, device_ids = device_ids)
        else:
            self.device = torch.device("cuda:{0:}".format(device_ids[0]))
        self.net.to(self.device)
        ckpt_dir    = self.config['training']['ckpt_save_dir']
        ckpt_prefix = self.config['training'].get('ckpt_prefix', None)
        if(ckpt_prefix is None):
            ckpt_prefix = ckpt_dir.split('/')[-1]
        iter_start = self.config['training']['iter_start']
        iter_max   = self.config['training']['iter_max']
        iter_valid = self.config['training']['iter_valid']
        iter_save  = self.config['training'].get('iter_save', None)
        early_stop_it = self.config['training'].get('early_stop_patience', None)
        if(iter_save is None):
            iter_save_list = [iter_max]
        elif(isinstance(iter_save, (tuple, list))):
            iter_save_list = iter_save
        else:
            iter_save_list = range(0, iter_max + 1, iter_save)

        self.min_val_loss = 10000.0
        self.max_val_it   = 0
        self.best_model_wts = None
        self.checkpoint = None
        if(iter_start > 0):
            checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start)
            self.checkpoint = torch.load(checkpoint_file, map_location = self.device)
            # assert(self.checkpoint['iteration'] == iter_start)
            if(len(device_ids) > 1):
                self.net.module.load_state_dict(self.checkpoint['model_state_dict'])
            else:
                self.net.load_state_dict(self.checkpoint['model_state_dict'])
            self.min_val_loss = self.checkpoint.get('valid_loss', 10000)
            # self.max_val_it = self.checkpoint['iteration']
            self.max_val_it = iter_start
            self.best_model_wts = self.checkpoint['model_state_dict']

        self.create_optimizer(self.get_parameters_to_update())
        self.create_loss_calculator()

        self.trainIter = iter(self.train_loader)

        logging.info("{0:} training start".format(str(datetime.now())[:-7]))
        self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir'])
        self.glob_it = iter_start
        for it in range(iter_start, iter_max, iter_valid):
            lr_value = self.optimizer.param_groups[0]['lr']
            t0 = time.time()
            train_scalars = self.training()
            t1 = time.time()
            valid_scalars = self.validation()
            t2 = time.time()
            if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                self.scheduler.step(-valid_scalars['loss'])
            else:
                self.scheduler.step()

            self.glob_it = it + iter_valid
            logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it))
            logging.info('learning rate {0:}'.format(lr_value))
            logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1))
            self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)
            if(valid_scalars['loss'] < self.min_val_loss):
                self.min_val_loss = valid_scalars['loss']
                self.max_val_it   = self.glob_it
                if(len(device_ids) > 1):
                    self.best_model_wts = copy.deepcopy(self.net.module.state_dict())
                else:
                    self.best_model_wts = copy.deepcopy(self.net.state_dict())

            stop_now = (early_stop_it is not None and
                        self.glob_it - self.max_val_it > early_stop_it)
            if((self.glob_it in iter_save_list) or stop_now):
                save_dict = {'iteration': self.glob_it,
                             'valid_loss': valid_scalars['loss'],
                             'model_state_dict': self.net.module.state_dict() \
                                 if len(device_ids) > 1 else self.net.state_dict(),
                             'optimizer_state_dict': self.optimizer.state_dict()}
                save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it)
                torch.save(save_dict, save_name)
                with open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') as txt_file:
                    txt_file.write(str(self.glob_it))
            if(stop_now):
                logging.info("The training is early stopped")
                break
        # save the best performing checkpoint
        save_dict = {'iteration': self.max_val_it,
                     'valid_loss': self.min_val_loss,
                     'model_state_dict': self.best_model_wts,
                     'optimizer_state_dict': self.optimizer.state_dict()}
        save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it)
        torch.save(save_dict, save_name)
        with open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') as txt_file:
            txt_file.write(str(self.max_val_it))
        logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\
            self.max_val_it, self.min_val_loss))
        self.summ_writer.close()
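
A hypothetical usage sketch for the agent defined above; the config file name is illustrative, and `run()` is inherited from the base agent class (it is what the entry script below calls):

    from pymic.util.parse_config import parse_config, synchronize_config
    from pymic.net_run_self_sl.self_sl_agent import SelfSLSegAgent

    config = parse_config("self_sl_demo.cfg")   # hypothetical config file
    config = synchronize_config(config)
    agent  = SelfSLSegAgent(config, stage = 'train')
    agent.run()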

pymic/net_run_self_sl/self_sl_main.py (new file)

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division
import logging
import os
import sys
import shutil
from pymic.util.parse_config import *
from pymic.net_run_self_sl.self_sl_agent import SelfSLSegAgent

def model_genesis(stage, cfg_file):
    config = parse_config(cfg_file)
    transforms  = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting']
    genesis_cfg = {
        'randomflip_flip_depth': True,
        'randomflip_flip_height': True,
        'randomflip_flip_width': True,
        'localshuffling_probability': 0.5,
        'nonLineartransform_probability': 0.9,
        'inoutpainting_probability': 0.9,
        'inpainting_probability': 0.2
    }
    config['dataset']['train_transform'].extend(transforms)
    config['dataset']['valid_transform'].extend(transforms)
    config['dataset'].update(genesis_cfg)

    config  = synchronize_config(config)
    log_dir = config['training']['ckpt_save_dir']
    if(not os.path.exists(log_dir)):
        os.mkdir(log_dir)
    if(stage == "train"):
        dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1]
        shutil.copy(cfg_file, log_dir + "/" + dst_cfg)
    if sys.version.startswith("3.9"):
        logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
            format='%(message)s', force=True) # for python 3.9
    else:
        logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
            format='%(message)s') # for python 3.6
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging_config(config)
    agent = SelfSLSegAgent(config, stage)
    agent.run()

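The `model_genesis` branch above is selected through the configuration file: the `__main__` block at the bottom of this file reads `self_sl_method` from a `self_supervised_learning` section. The corresponding fragment would be:

    [self_supervised_learning]
    self_sl_method = model_genesis
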
def default_self_sl(stage, cfg_file):
    config  = parse_config(cfg_file)
    config  = synchronize_config(config)
    log_dir = config['training']['ckpt_save_dir']
    if(not os.path.exists(log_dir)):
        os.mkdir(log_dir)
    if(stage == "train"):
        dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1]
        shutil.copy(cfg_file, log_dir + "/" + dst_cfg)
    if sys.version.startswith("3.9"):
        logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
            format='%(message)s', force=True) # for python 3.9
    else:
        logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
            format='%(message)s') # for python 3.6
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging_config(config)
    agent = SelfSLSegAgent(config, stage)
    agent.run()

def main():
    # setup.py registers this function as the console script `pymic_self_sl`
    if(len(sys.argv) < 3):
        print('Number of arguments should be 3. e.g.')
        print('    pymic_self_sl train config.cfg')
        sys.exit(1)
    stage    = str(sys.argv[1])
    cfg_file = str(sys.argv[2])
    config   = parse_config(cfg_file)
    method   = "default"
    if 'self_supervised_learning' in config:
        method = config['self_supervised_learning'].get('self_sl_method', 'default')
    print("the self-supervised method is", method)
    if(method == "default"):
        default_self_sl(stage, cfg_file)
    elif(method == 'model_genesis'):
        model_genesis(stage, cfg_file)
    else:
        raise ValueError("The specified method {0:} is not implemented. ".format(method) + \
            "Consider setting `self_sl_method = default` and using customized" + \
            " transforms for self-supervised learning.")

if __name__ == "__main__":
    main()
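
With the package installed, this script is reachable through the `pymic_self_sl` console script registered in setup.py below, matching the usage message above (the config file name is illustrative):

    pymic_self_sl train config.cfg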

setup.py

Lines changed: 3 additions & 2 deletions

@@ -11,7 +11,7 @@
 
 setuptools.setup(
     name = 'PYMIC',
-    version = "0.3.1",
+    version = "0.3.1.1",
     author = 'PyMIC Consortium',
     author_email = 'wguotai@gmail.com',
     description = description,
@@ -42,7 +42,8 @@
     entry_points = {
         'console_scripts': [
             'pymic_run = pymic.net_run.net_run:main',
-            'pymic_ssl = pymic.net_run_ssl.ssl_main:main',
+            'pymic_semi_sl = pymic.net_run_ssl.ssl_main:main',
+            'pymic_self_sl = pymic.net_run_self_sl.self_sl_main:main',
             'pymic_wsl = pymic.net_run_wsl.wsl_main:main',
             'pymic_nll = pymic.net_run_nll.nll_main:main',
             'pymic_eval_cls = pymic.util.evaluation_cls:main',
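
Note the renamed entry point: semi-supervised training, previously exposed as `pymic_ssl`, is now `pymic_semi_sl`, and `pymic_self_sl` is new in this commit. Assuming an existing installation, invocations would look like this (config names are illustrative):

    pymic_semi_sl train semi_sup.cfg    # formerly: pymic_ssl train semi_sup.cfg
    pymic_self_sl train self_sup.cfg    # new in this commit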
