Skip to content

Commit 3340cb0

Browse files
committed
reorganize code
add train.py and predict.py move semi_sup, self_sup, noisy_label and weak_sup to net_run
1 parent d4d51dc commit 3340cb0

34 files changed

+172
-286
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from __future__ import absolute_import
2+
from pymic.net_run.noisy_label.nll_co_teaching import NLLCoTeaching
3+
from pymic.net_run.noisy_label.nll_trinet import NLLTriNet
4+
from pymic.net_run.noisy_label.nll_dast import NLLDAST
5+
6+
NLLMethodDict = {'CoTeaching': NLLCoTeaching,
7+
"TriNet": NLLTriNet,
8+
"DAST": NLLDAST}
File renamed without changes.
File renamed without changes.
File renamed without changes.

pymic/net_run/net_run.py renamed to pymic/net_run/predict.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import os
55
import sys
6-
import shutil
6+
from datetime import datetime
77
from pymic.util.parse_config import *
88
from pymic.net_run.agent_cls import ClassificationAgent
99
from pymic.net_run.agent_seg import SegmentationAgent
@@ -12,34 +12,31 @@ def main():
1212
"""
1313
The main function for running a network for training or inference.
1414
"""
15-
if(len(sys.argv) < 3):
16-
print('Number of arguments should be 3. e.g.')
17-
print(' pymic_run train config.cfg')
15+
if(len(sys.argv) < 2):
16+
print('Number of arguments should be 2. e.g.')
17+
print(' pymic_test config.cfg')
1818
exit()
19-
stage = str(sys.argv[1])
20-
cfg_file = str(sys.argv[2])
19+
cfg_file = str(sys.argv[1])
2120
config = parse_config(cfg_file)
2221
config = synchronize_config(config)
23-
log_dir = config['training']['ckpt_save_dir']
22+
log_dir = config['testing']['output_dir']
2423
if(not os.path.exists(log_dir)):
2524
os.makedirs(log_dir, exist_ok=True)
26-
if(stage == "train"):
27-
dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1]
28-
shutil.copy(cfg_file, log_dir + "/" + dst_cfg)
25+
2926
if sys.version.startswith("3.9"):
30-
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
31-
format='%(message)s', force=True) # for python 3.9
27+
logging.basicConfig(filename=log_dir+"/log_test.txt",
28+
level=logging.INFO, format='%(message)s', force=True) # for python 3.9
3229
else:
33-
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
34-
format='%(message)s') # for python 3.6
30+
logging.basicConfig(filename=log_dir+"/log_test.txt",
31+
level=logging.INFO, format='%(message)s') # for python 3.6
3532
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
3633
logging_config(config)
3734
task = config['dataset']['task_type']
3835
assert task in ['cls', 'cls_nexcl', 'seg']
3936
if(task == 'cls' or task == 'cls_nexcl'):
40-
agent = ClassificationAgent(config, stage)
37+
agent = ClassificationAgent(config, 'test')
4138
else:
42-
agent = SegmentationAgent(config, stage)
39+
agent = SegmentationAgent(config, 'test')
4340
agent.run()
4441

4542
if __name__ == "__main__":

pymic/net_run/self_sup/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from __future__ import absolute_import
2+
from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent
File renamed without changes.

pymic/net_run/semi_sup/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from __future__ import absolute_import
2+
from pymic.net_run.semi_sup.ssl_abstract import SSLSegAgent
3+
from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization
4+
from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher
5+
from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher
6+
from pymic.net_run.semi_sup.ssl_cct import SSLCCT
7+
from pymic.net_run.semi_sup.ssl_cps import SSLCPS
8+
from pymic.net_run.semi_sup.ssl_urpc import SSLURPC
9+
10+
11+
SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization,
12+
'MeanTeacher': SSLMeanTeacher,
13+
'UAMT': SSLUncertaintyAwareMeanTeacher,
14+
'CCT': SSLCCT,
15+
'CPS': SSLCPS,
16+
'URPC': SSLURPC}
File renamed without changes.

0 commit comments

Comments
 (0)